feat: add similarity model
This commit is contained in:
131
models/SiaN-similarity/README.md
Normal file
131
models/SiaN-similarity/README.md
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
# SiaN-Similarity: Модель для оценки схожести изображений
|
||||||
|
|
||||||
|
Модель для оценки схожести между двумя изображениями 256x256. Возвращает значение от 0 до 1, где 1 означает полную схожесть, 0 - полное различие.
|
||||||
|
|
||||||
|
## Архитектура модели
|
||||||
|
|
||||||
|
Модель основана на CNN с residual блоками:
|
||||||
|
- Общий энкодер для обоих изображений
|
||||||
|
- Residual blocks с batch normalization
|
||||||
|
- Слой слияния признаков
|
||||||
|
- Регрессионная голова с сигмоидой на выходе
|
||||||
|
|
||||||
|
## Использование
|
||||||
|
|
||||||
|
### Установка зависимостей
|
||||||
|
```bash
|
||||||
|
pip install torch torchvision pillow
|
||||||
|
```
|
||||||
|
|
||||||
|
### Быстрый старт
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from model import SimilarityCNN
|
||||||
|
|
||||||
|
# Создание модели
|
||||||
|
model = SimilarityCNN(
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Предсказание схожести
|
||||||
|
img1 = torch.randn(1, 3, 256, 256) # Изображение 1
|
||||||
|
img2 = torch.randn(1, 3, 256, 256) # Изображение 2
|
||||||
|
|
||||||
|
similarity = model.predict_similarity(img1, img2)
|
||||||
|
print(f"Схожесть: {similarity.item():.4f}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Обучение модели
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python train_similarity.py \
|
||||||
|
--data_dir "путь/к/данным" \
|
||||||
|
--batch_size 32 \
|
||||||
|
--epochs 100 \
|
||||||
|
--learning_rate 2e-4 \
|
||||||
|
--output_dir "runs/similarity"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Предсказание на новых изображениях
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python predict.py \
|
||||||
|
--image1 "путь/к/изображению1.png" \
|
||||||
|
--image2 "путь/к/изображению2.png" \
|
||||||
|
--checkpoint "runs/similarity/checkpoints/best_model.pt"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Структура проекта
|
||||||
|
|
||||||
|
```
|
||||||
|
SiaN-similarity/
|
||||||
|
├── model.py # Основная модель
|
||||||
|
├── dataloader.py # Даталоадер для обучения
|
||||||
|
├── train_similarity.py # Скрипт для обучения
|
||||||
|
├── predict.py # Скрипт для предсказания
|
||||||
|
├── train.py # Оригинальный тренировочный скрипт
|
||||||
|
└── README.md # Этот файл
|
||||||
|
```
|
||||||
|
|
||||||
|
## Конфигурация модели
|
||||||
|
|
||||||
|
Параметры по умолчанию:
|
||||||
|
- `input_channels`: 3 (RGB)
|
||||||
|
- `hidden_channels`: 64
|
||||||
|
- `num_blocks`: 4
|
||||||
|
- `dropout_rate`: 0.3
|
||||||
|
- `use_batch_norm`: True
|
||||||
|
- `image_size`: (256, 256)
|
||||||
|
|
||||||
|
## Формат данных
|
||||||
|
|
||||||
|
Модель ожидает изображения размером 256x256 пикселей в формате RGB.
|
||||||
|
Для обучения используется датасет с парами изображений и метками схожести.
|
||||||
|
|
||||||
|
## Примеры использования
|
||||||
|
|
||||||
|
### 1. Создание и тестирование модели
|
||||||
|
```python
|
||||||
|
from model import create_similarity_model
|
||||||
|
|
||||||
|
model = create_similarity_model(
|
||||||
|
model_type="cnn",
|
||||||
|
input_size=(256, 256),
|
||||||
|
hidden_channels=32,
|
||||||
|
num_blocks=3,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Использование функции потерь
|
||||||
|
```python
|
||||||
|
from model import SimilarityLoss
|
||||||
|
|
||||||
|
loss_fn = SimilarityLoss()
|
||||||
|
pred = torch.tensor([[0.8], [0.2]])
|
||||||
|
target = torch.tensor([[1.0], [0.0]])
|
||||||
|
loss = loss_fn(pred, target)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Расчет метрик
|
||||||
|
```python
|
||||||
|
metrics = loss_fn.compute_metrics(pred, target)
|
||||||
|
print(f"Accuracy: {metrics['accuracy']:.4f}")
|
||||||
|
print(f"F1-score: {metrics['f1']:.4f}")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Требования
|
||||||
|
|
||||||
|
- Python 3.8+
|
||||||
|
- PyTorch 1.9+
|
||||||
|
- torchvision
|
||||||
|
- Pillow
|
||||||
|
- numpy
|
||||||
|
|
||||||
|
## Лицензия
|
||||||
|
|
||||||
|
MIT
|
||||||
519
models/SiaN-similarity/dataloader.py
Normal file
519
models/SiaN-similarity/dataloader.py
Normal file
@@ -0,0 +1,519 @@
|
|||||||
|
config = {
|
||||||
|
# Параметры оптимизатора
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
# Параметры обучения
|
||||||
|
"batch_size": 4,
|
||||||
|
"epochs": 100,
|
||||||
|
# Параметры GAN
|
||||||
|
"gan_mode": "vanilla", # "vanilla", "lsgan", или "wgangp"
|
||||||
|
"lambda_L1": 100.0, # Вес L1 потерь
|
||||||
|
# Регуляризация
|
||||||
|
"grad_clip": 1.0,
|
||||||
|
# Ранняя остановка
|
||||||
|
"early_stopping_patience": 20,
|
||||||
|
# Выходные данные
|
||||||
|
"output_dir": "runs/gan_training",
|
||||||
|
# Логирование
|
||||||
|
"log_interval": 10, # Логировать каждые N батчей
|
||||||
|
"save_interval": 5, # Сохранять чекпоинт каждые N эпох
|
||||||
|
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||||
|
"batch_size": 32,
|
||||||
|
"image_size": [256, 256],
|
||||||
|
"train_split": 0.8,
|
||||||
|
"num_workers": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class YaGoDataset(Dataset):
|
||||||
|
"""
|
||||||
|
Dataset for homography estimation between Yandex and Google map image pairs.
|
||||||
|
|
||||||
|
This dataset loads pairs of images (Yandex and Google maps) and provides
|
||||||
|
homography matrices for data augmentation and training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
root_dir: str,
|
||||||
|
transform=None,
|
||||||
|
augment: bool = True,
|
||||||
|
max_samples: Optional[int] = None,
|
||||||
|
image_size: Tuple[int, int] = (700, 700),
|
||||||
|
cache_homographies: bool = True,
|
||||||
|
device: torch.device = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the YaGoDataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root_dir: Directory containing image pairs (format: {idx:04d}_google.png, {idx:04d}_yandex.png)
|
||||||
|
transform: Optional torchvision transforms to apply
|
||||||
|
augment: Whether to apply homography-based data augmentation
|
||||||
|
max_samples: Maximum number of samples to load (None for all)
|
||||||
|
image_size: Target size for images (height, width)
|
||||||
|
cache_homographies: Whether to cache generated homography matrices to disk
|
||||||
|
"""
|
||||||
|
self.root_dir = root_dir
|
||||||
|
self.transform = transform
|
||||||
|
self.augment = augment
|
||||||
|
self.image_size = image_size
|
||||||
|
self.cache_homographies = cache_homographies
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# Find all image pairs
|
||||||
|
self.image_pairs = self._discover_image_pairs()
|
||||||
|
|
||||||
|
if max_samples is not None:
|
||||||
|
self.image_pairs = self.image_pairs[:max_samples]
|
||||||
|
|
||||||
|
print(f"Found {len(self.image_pairs)} image pairs in {root_dir}")
|
||||||
|
|
||||||
|
def _discover_image_pairs(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Discover all Google-Yandex image pairs in the dataset directory."""
|
||||||
|
image_pairs = []
|
||||||
|
|
||||||
|
# Get all Google images
|
||||||
|
google_files = [
|
||||||
|
f for f in os.listdir(self.root_dir) if f.endswith("_google.png")
|
||||||
|
]
|
||||||
|
|
||||||
|
for google_file in sorted(google_files):
|
||||||
|
# Extract index from filename
|
||||||
|
idx_str = google_file.split("_")[0]
|
||||||
|
try:
|
||||||
|
idx = int(idx_str)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if corresponding Yandex image exists
|
||||||
|
yandex_file = f"{idx:04d}_yandex.png"
|
||||||
|
yandex_path = os.path.join(self.root_dir, yandex_file)
|
||||||
|
|
||||||
|
if os.path.exists(yandex_path):
|
||||||
|
image_pairs.append(
|
||||||
|
{
|
||||||
|
"idx": idx,
|
||||||
|
"google_path": os.path.join(self.root_dir, google_file),
|
||||||
|
"yandex_path": yandex_path,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return image_pairs
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Return the number of image pairs in the dataset."""
|
||||||
|
return len(self.image_pairs)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Get a sample from the dataset.
|
||||||
|
|
||||||
|
Returns a dictionary with:
|
||||||
|
- 'google_img': Google map image tensor
|
||||||
|
- 'yandex_img': Yandex map image tensor
|
||||||
|
- 'homography': Ground truth homography matrix (3x3)
|
||||||
|
- 'idx': Sample index
|
||||||
|
"""
|
||||||
|
pair_info = self.image_pairs[idx]
|
||||||
|
google_path = pair_info["google_path"]
|
||||||
|
yandex_path = pair_info["yandex_path"]
|
||||||
|
same_domain = True
|
||||||
|
|
||||||
|
if np.random.rand() > 0.5:
|
||||||
|
random_idx = np.random.randint(0, len(self))
|
||||||
|
google_path = self.image_pairs[random_idx]["google_path"]
|
||||||
|
same_domain = random_idx == idx
|
||||||
|
|
||||||
|
# Load images
|
||||||
|
yandex_img = Image.open(yandex_path).convert("RGB")
|
||||||
|
google_img = Image.open(google_path).convert("RGB")
|
||||||
|
|
||||||
|
# Resize images to target size
|
||||||
|
google_img = google_img.resize(
|
||||||
|
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||||
|
)
|
||||||
|
yandex_img = yandex_img.resize(
|
||||||
|
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get or generate homography matrix
|
||||||
|
matrices: Tuple[np.ndarray, np.ndarray, np.ndarray] = (
|
||||||
|
self._get_homography_matrix(pair_info["idx"])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply data augmentation if enabled
|
||||||
|
if self.augment:
|
||||||
|
google_img, yandex_img, homography_matrix = self._apply_augmentation(
|
||||||
|
google_img, yandex_img, matrices
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert images to tensors
|
||||||
|
if self.transform:
|
||||||
|
google_img = self.transform(google_img)
|
||||||
|
yandex_img = self.transform(yandex_img)
|
||||||
|
else:
|
||||||
|
# Default conversion to tensor
|
||||||
|
google_img = (
|
||||||
|
torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0
|
||||||
|
)
|
||||||
|
yandex_img = (
|
||||||
|
torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert homography to tensor
|
||||||
|
if self.augment:
|
||||||
|
homography_tensor = torch.from_numpy(homography_matrix).float()
|
||||||
|
else:
|
||||||
|
homography_tensor = torch.from_numpy(np.eye(3))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"google_img": google_img,
|
||||||
|
"yandex_img": yandex_img,
|
||||||
|
"homography": homography_tensor,
|
||||||
|
"same_domain": same_domain,
|
||||||
|
"idx": torch.tensor(pair_info["idx"], dtype=torch.long),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_homography_matrix(
|
||||||
|
self, idx: int
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Get homography matrices for a given index.
|
||||||
|
|
||||||
|
If cached homography exists, load it. Otherwise generate a new one.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Generate new homography matrix
|
||||||
|
homography_matrix_1 = self.generate_random_homography()
|
||||||
|
homography_matrix_2 = self.generate_random_homography()
|
||||||
|
homography_matrix_r = np.linalg.inv(homography_matrix_1) @ homography_matrix_2
|
||||||
|
|
||||||
|
result = (homography_matrix_1, homography_matrix_2, homography_matrix_r)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def generate_random_homography(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Generate a random homography matrix for data augmentation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: 3x3 homography matrix.
|
||||||
|
"""
|
||||||
|
# Generate random affine transformation parameters
|
||||||
|
scale = np.random.uniform(0.8, 1.2) # scaling factor
|
||||||
|
tx = np.random.uniform(-0.50, 0.50) # translation in x
|
||||||
|
ty = np.random.uniform(-0.50, 0.50) # translation in y
|
||||||
|
|
||||||
|
# rotation
|
||||||
|
angle_x = np.random.uniform(np.radians(-10), np.radians(10))
|
||||||
|
angle_y = np.random.uniform(np.radians(-10), np.radians(10))
|
||||||
|
angle_z = np.random.uniform(np.radians(-10), np.radians(10))
|
||||||
|
|
||||||
|
cy, sy = np.cos(angle_z), np.sin(angle_z)
|
||||||
|
cp, sp = np.cos(angle_y), np.sin(angle_y)
|
||||||
|
cr, sr = np.cos(angle_x), np.sin(angle_x)
|
||||||
|
|
||||||
|
Rz = np.array(
|
||||||
|
[
|
||||||
|
[cy, -sy, 0],
|
||||||
|
[sy, cy, 0],
|
||||||
|
[0, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
Ry = np.array(
|
||||||
|
[
|
||||||
|
[cp, 0, sp],
|
||||||
|
[0, 1, 0],
|
||||||
|
[-sp, 0, cp],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
Rx = np.array(
|
||||||
|
[
|
||||||
|
[1, 0, 0],
|
||||||
|
[0, cr, -sr],
|
||||||
|
[0, sr, cr],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create affine transformation matrix
|
||||||
|
T = np.array(
|
||||||
|
[
|
||||||
|
[1, 0, tx],
|
||||||
|
[0, 1, ty],
|
||||||
|
[0, 0, scale],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
K = self.get_camera_matrix()
|
||||||
|
|
||||||
|
return K @ Rx @ Ry @ Rz @ T @ np.linalg.inv(K)
|
||||||
|
|
||||||
|
def get_camera_matrix(self) -> np.ndarray:
|
||||||
|
w, h = config["image_size"]
|
||||||
|
|
||||||
|
K = np.array(
|
||||||
|
[
|
||||||
|
[w / 2, 0, w / 2],
|
||||||
|
[0, h / 2, h / 2],
|
||||||
|
[0, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return K
|
||||||
|
|
||||||
|
def _apply_augmentation(
|
||||||
|
self,
|
||||||
|
google_img: Image.Image,
|
||||||
|
yandex_img: Image.Image,
|
||||||
|
matrices: Tuple[np.ndarray, np.ndarray, np.ndarray],
|
||||||
|
) -> Tuple[Image.Image, Image.Image, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Apply homography-based data augmentation to image pair.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
google_img: Google map image
|
||||||
|
yandex_img: Yandex map image
|
||||||
|
matrices: homography matrices
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (augmented_google_img, augmented_yandex_img, augmented_homography)
|
||||||
|
"""
|
||||||
|
# Combine with base homography
|
||||||
|
combined_homography = matrices[2]
|
||||||
|
|
||||||
|
# Apply augmentation to both images
|
||||||
|
# google_aug = self._apply_homography_to_image(google_img, aug_homography)
|
||||||
|
yandex_aug = self._apply_homography_to_image(yandex_img, matrices[0])
|
||||||
|
google_aug = self._apply_homography_to_image(google_img, matrices[1])
|
||||||
|
|
||||||
|
return google_aug, yandex_aug, combined_homography
|
||||||
|
|
||||||
|
def _apply_homography_to_image(
|
||||||
|
self, img: Image.Image, homography: np.ndarray
|
||||||
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
|
Apply homography transformation to a single image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: PIL Image to transform
|
||||||
|
homography: 3x3 homography matrix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Transformed PIL Image
|
||||||
|
"""
|
||||||
|
# Convert to numpy array
|
||||||
|
img_np = np.array(img)
|
||||||
|
|
||||||
|
# Get image dimensions
|
||||||
|
h, w = img_np.shape[:2]
|
||||||
|
|
||||||
|
# Apply homography transformation
|
||||||
|
transformed = cv2.warpPerspective(
|
||||||
|
img_np,
|
||||||
|
homography,
|
||||||
|
(w, h),
|
||||||
|
flags=cv2.INTER_LINEAR,
|
||||||
|
# borderMode=cv2.BORDER_REFLECT,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert back to PIL Image
|
||||||
|
return Image.fromarray(transformed)
|
||||||
|
|
||||||
|
def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get a sample without data augmentation.
|
||||||
|
|
||||||
|
Useful for visualization and evaluation.
|
||||||
|
"""
|
||||||
|
pair_info = self.image_pairs[idx]
|
||||||
|
|
||||||
|
# Load images
|
||||||
|
google_img = Image.open(pair_info["google_path"]).convert("RGB")
|
||||||
|
yandex_img = Image.open(pair_info["yandex_path"]).convert("RGB")
|
||||||
|
|
||||||
|
# Resize
|
||||||
|
google_img = google_img.resize(
|
||||||
|
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||||
|
)
|
||||||
|
yandex_img = yandex_img.resize(
|
||||||
|
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get homography matrix
|
||||||
|
homography_matrix = self._get_homography_matrix(pair_info["idx"])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"google_img": google_img,
|
||||||
|
"yandex_img": yandex_img,
|
||||||
|
"homography": homography_matrix,
|
||||||
|
"idx": pair_info["idx"],
|
||||||
|
"google_path": pair_info["google_path"],
|
||||||
|
"yandex_path": pair_info["yandex_path"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_data_loaders(
|
||||||
|
root_dir: str,
|
||||||
|
batch_size: int = 32,
|
||||||
|
train_split: float = 0.8,
|
||||||
|
num_workers: int = 4,
|
||||||
|
image_size: Tuple[int, int] = (256, 256),
|
||||||
|
augment_train: bool = True,
|
||||||
|
augment_val: bool = False,
|
||||||
|
device: torch.device = None,
|
||||||
|
) -> Tuple[DataLoader, DataLoader]:
|
||||||
|
"""
|
||||||
|
Create train and validation data loaders for homography estimation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root_dir: Directory containing image pairs
|
||||||
|
batch_size: Batch size for data loaders
|
||||||
|
train_split: Fraction of data to use for training
|
||||||
|
num_workers: Number of worker processes for data loading
|
||||||
|
image_size: Target image size (height, width)
|
||||||
|
augment_train: Whether to augment training data
|
||||||
|
augment_val: Whether to augment validation data
|
||||||
|
device: Target device for tensors (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (train_loader, val_loader)
|
||||||
|
"""
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
# Define transforms
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create full dataset
|
||||||
|
full_dataset = YaGoDataset(
|
||||||
|
root_dir=root_dir,
|
||||||
|
transform=transform,
|
||||||
|
augment=False, # We'll handle augmentation separately
|
||||||
|
image_size=image_size,
|
||||||
|
cache_homographies=True,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Split dataset
|
||||||
|
dataset_size = len(full_dataset)
|
||||||
|
train_size = int(train_split * dataset_size)
|
||||||
|
val_size = dataset_size - train_size
|
||||||
|
|
||||||
|
# Create indices for splitting
|
||||||
|
indices = list(range(dataset_size))
|
||||||
|
random.shuffle(indices)
|
||||||
|
train_indices = indices[:train_size]
|
||||||
|
val_indices = indices[train_size:]
|
||||||
|
|
||||||
|
# Create subset samplers
|
||||||
|
from torch.utils.data import Subset
|
||||||
|
|
||||||
|
train_dataset = Subset(full_dataset, train_indices)
|
||||||
|
val_dataset = Subset(full_dataset, val_indices)
|
||||||
|
|
||||||
|
# Apply augmentation by overriding __getitem__ for train dataset
|
||||||
|
if augment_train:
|
||||||
|
|
||||||
|
class AugmentedSubset(Subset):
|
||||||
|
def __init__(self, dataset, indices, device=None):
|
||||||
|
super().__init__(dataset, indices)
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
sample = self.dataset[self.indices[idx]]
|
||||||
|
# Apply augmentation
|
||||||
|
google_img = sample["google_img"]
|
||||||
|
yandex_img = sample["yandex_img"]
|
||||||
|
homography = sample["homography"]
|
||||||
|
|
||||||
|
if self.device is not None:
|
||||||
|
google_img = google_img.to(self.device)
|
||||||
|
yandex_img = yandex_img.to(self.device)
|
||||||
|
homography = homography.to(self.device)
|
||||||
|
|
||||||
|
# Generate augmentation homography
|
||||||
|
aug_homography = torch.from_numpy(
|
||||||
|
full_dataset.generate_random_homography()
|
||||||
|
).float()
|
||||||
|
|
||||||
|
if self.device is not None:
|
||||||
|
aug_homography = aug_homography.to(self.device)
|
||||||
|
|
||||||
|
# Combine homographies
|
||||||
|
combined_homography = aug_homography @ homography
|
||||||
|
|
||||||
|
# Apply augmentation (simplified - in practice would warp images)
|
||||||
|
# For now, we just return the combined homography
|
||||||
|
return {
|
||||||
|
"google_img": google_img,
|
||||||
|
"yandex_img": yandex_img,
|
||||||
|
"homography": combined_homography,
|
||||||
|
"idx": sample["idx"],
|
||||||
|
}
|
||||||
|
|
||||||
|
train_dataset = AugmentedSubset(full_dataset, train_indices, device=device)
|
||||||
|
|
||||||
|
# Create data loaders
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_loader, val_loader
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
dataset = YaGoDataset(
|
||||||
|
root_dir=config["data_dir"],
|
||||||
|
augment=False,
|
||||||
|
image_size=(256, 256),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Dataset size: {len(dataset)}")
|
||||||
|
|
||||||
|
# Get a sample
|
||||||
|
sample = dataset[0]
|
||||||
|
print(f"Sample keys: {list(sample.keys())}")
|
||||||
|
print(f"Google image shape: {sample['google_img'].shape}")
|
||||||
|
print(f"Yandex image shape: {sample['yandex_img'].shape}")
|
||||||
|
print(f"Homography shape: {sample['homography'].shape}")
|
||||||
|
|
||||||
|
# Create data loaders
|
||||||
|
train_loader, val_loader = create_data_loaders(
|
||||||
|
root_dir=config["data_dir"],
|
||||||
|
batch_size=16,
|
||||||
|
train_split=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Train batches: {len(train_loader)}")
|
||||||
|
print(f"Val batches: {len(val_loader)}")
|
||||||
192
models/SiaN-similarity/demo.py
Normal file
192
models/SiaN-similarity/demo.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""
|
||||||
|
Демонстрационный скрипт для модели оценки схожести изображений.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from model import SimilarityCNN, SimilarityLoss
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_images():
|
||||||
|
"""Создание тестовых изображений для демонстрации."""
|
||||||
|
images = []
|
||||||
|
|
||||||
|
# Изображение 1: Красный квадрат
|
||||||
|
img1 = Image.new("RGB", (256, 256), color="white")
|
||||||
|
draw = ImageDraw.Draw(img1)
|
||||||
|
draw.rectangle([50, 50, 200, 200], fill="red", outline="black", width=2)
|
||||||
|
images.append(("Красный квадрат", img1))
|
||||||
|
|
||||||
|
# Изображение 2: Тот же красный квадрат (похожее)
|
||||||
|
img2 = Image.new("RGB", (256, 256), color="white")
|
||||||
|
draw = ImageDraw.Draw(img2)
|
||||||
|
draw.rectangle([55, 55, 205, 205], fill="red", outline="black", width=2)
|
||||||
|
images.append(("Похожий красный квадрат", img2))
|
||||||
|
|
||||||
|
# Изображение 3: Синий круг (разное)
|
||||||
|
img3 = Image.new("RGB", (256, 256), color="white")
|
||||||
|
draw = ImageDraw.Draw(img3)
|
||||||
|
draw.ellipse([50, 50, 200, 200], fill="blue", outline="black", width=2)
|
||||||
|
images.append(("Синий круг", img3))
|
||||||
|
|
||||||
|
# Изображение 4: Зеленый треугольник (разное)
|
||||||
|
img4 = Image.new("RGB", (256, 256), color="white")
|
||||||
|
draw = ImageDraw.Draw(img4)
|
||||||
|
draw.polygon(
|
||||||
|
[(128, 50), (50, 200), (200, 200)], fill="green", outline="black", width=2
|
||||||
|
)
|
||||||
|
images.append(("Зеленый треугольник", img4))
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image(image):
|
||||||
|
"""Преобразование PIL Image в тензор."""
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return transform(image).unsqueeze(0) # Добавляем batch dimension
|
||||||
|
|
||||||
|
|
||||||
|
def display_results(images, similarities):
|
||||||
|
"""Отображение результатов сравнения."""
|
||||||
|
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||||||
|
axes = axes.flatten()
|
||||||
|
|
||||||
|
for idx, (title, img) in enumerate(images):
|
||||||
|
ax = axes[idx]
|
||||||
|
ax.imshow(img)
|
||||||
|
ax.set_title(title, fontsize=12, fontweight="bold")
|
||||||
|
ax.axis("off")
|
||||||
|
|
||||||
|
plt.suptitle("Тестовые изображения", fontsize=16, fontweight="bold")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# Вывод результатов сравнения
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("РЕЗУЛЬТАТЫ СРАВНЕНИЯ ИЗОБРАЖЕНИЙ")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
comparisons = [
|
||||||
|
("Красный квадрат", "Похожий красный квадрат"),
|
||||||
|
("Красный квадрат", "Синий круг"),
|
||||||
|
("Красный квадрат", "Зеленый треугольник"),
|
||||||
|
("Похожий красный квадрат", "Синий круг"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, (name1, name2) in enumerate(comparisons):
|
||||||
|
idx1 = [idx for idx, (name, _) in enumerate(images) if name == name1][0]
|
||||||
|
idx2 = [idx for idx, (name, _) in enumerate(images) if name == name2][0]
|
||||||
|
|
||||||
|
sim = similarities[idx1, idx2]
|
||||||
|
interpretation = "ПОХОЖИ" if sim > 0.5 else "РАЗНЫЕ"
|
||||||
|
|
||||||
|
print(f"\n{name1} vs {name2}:")
|
||||||
|
print(f" Схожесть: {sim:.4f}")
|
||||||
|
print(f" Интерпретация: {interpretation}")
|
||||||
|
print(f" Уверенность: {'Высокая' if sim > 0.7 or sim < 0.3 else 'Средняя'}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_loss_function():
|
||||||
|
"""Тестирование функции потерь."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("ТЕСТИРОВАНИЕ ФУНКЦИИ ПОТЕРЬ")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
loss_fn = SimilarityLoss()
|
||||||
|
|
||||||
|
# Тестовые данные
|
||||||
|
predictions = torch.tensor([[0.9], [0.1], [0.7], [0.3]])
|
||||||
|
targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
|
||||||
|
|
||||||
|
# Расчет потерь
|
||||||
|
loss = loss_fn(predictions, targets)
|
||||||
|
print(f"\nПотери: {loss.item():.4f}")
|
||||||
|
|
||||||
|
# Расчет метрик
|
||||||
|
metrics = loss_fn.compute_metrics(predictions, targets)
|
||||||
|
print("\nМетрики:")
|
||||||
|
for key, value in metrics.items():
|
||||||
|
print(f" {key}: {value:.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основная функция демонстрации."""
|
||||||
|
print("ДЕМОНСТРАЦИЯ МОДЕЛИ ОЦЕНКИ СХОЖЕСТИ ИЗОБРАЖЕНИЙ")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Создание модели
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"\nУстройство: {device}")
|
||||||
|
|
||||||
|
model = SimilarityCNN(
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")
|
||||||
|
|
||||||
|
# Создание тестовых изображений
|
||||||
|
print("\nСоздание тестовых изображений...")
|
||||||
|
test_images = create_test_images()
|
||||||
|
|
||||||
|
# Преобразование изображений в тензоры
|
||||||
|
tensors = []
|
||||||
|
for name, img in test_images:
|
||||||
|
tensor = preprocess_image(img).to(device)
|
||||||
|
tensors.append(tensor)
|
||||||
|
|
||||||
|
# Расчет схожести между всеми парами изображений
|
||||||
|
print("\nРасчет схожести между изображениями...")
|
||||||
|
n_images = len(test_images)
|
||||||
|
similarity_matrix = np.zeros((n_images, n_images))
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(n_images):
|
||||||
|
for j in range(n_images):
|
||||||
|
if i <= j: # Рассчитываем только верхний треугольник
|
||||||
|
sim = model.predict_similarity(tensors[i], tensors[j])
|
||||||
|
similarity_matrix[i, j] = sim.item()
|
||||||
|
similarity_matrix[j, i] = sim.item() # Симметричная матрица
|
||||||
|
|
||||||
|
# Отображение результатов
|
||||||
|
display_results(test_images, similarity_matrix)
|
||||||
|
|
||||||
|
# Тестирование функции потерь
|
||||||
|
test_loss_function()
|
||||||
|
|
||||||
|
# Дополнительная информация
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("ИНФОРМАЦИЯ О МОДЕЛИ")
|
||||||
|
print("=" * 60)
|
||||||
|
print("\nАрхитектура модели:")
|
||||||
|
print("-" * 40)
|
||||||
|
print("Вход: два изображения 256x256x3")
|
||||||
|
print("Энкодер: CNN с residual блоками")
|
||||||
|
print("Слой слияния: объединение признаков")
|
||||||
|
print("Выход: значение схожести [0, 1]")
|
||||||
|
print("\nИнтерпретация результатов:")
|
||||||
|
print("- 0.8-1.0: Очень похожи")
|
||||||
|
print("- 0.6-0.8: Похожи")
|
||||||
|
print("- 0.4-0.6: Нейтрально")
|
||||||
|
print("- 0.2-0.4: Разные")
|
||||||
|
print("- 0.0-0.2: Совершенно разные")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("ДЕМОНСТРАЦИЯ ЗАВЕРШЕНА")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
216
models/SiaN-similarity/example.py
Normal file
216
models/SiaN-similarity/example.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
"""
|
||||||
|
Пример использования модели оценки схожести с даталоадером.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from dataloader import YaGoDataset, create_data_loaders
|
||||||
|
from model import SimilarityCNN, SimilarityLoss
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основной пример использования."""
|
||||||
|
print("ПРИМЕР ИСПОЛЬЗОВАНИЯ МОДЕЛИ СХОЖЕСТИ С ДАТАЛОАДЕРОМ")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Конфигурация
|
||||||
|
config = {
|
||||||
|
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||||
|
"batch_size": 4,
|
||||||
|
"image_size": (256, 256),
|
||||||
|
"train_split": 0.8,
|
||||||
|
"num_workers": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Устройство
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Устройство: {device}")
|
||||||
|
|
||||||
|
# 1. Создание датасета
|
||||||
|
print("\n1. СОЗДАНИЕ ДАТАСЕТА")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
dataset = YaGoDataset(
|
||||||
|
root_dir=config["data_dir"],
|
||||||
|
augment=False,
|
||||||
|
image_size=config["image_size"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Размер датасета: {len(dataset)} пар изображений")
|
||||||
|
|
||||||
|
# Получение примера из датасета
|
||||||
|
sample = dataset[0]
|
||||||
|
print(f"\nПример из датасета:")
|
||||||
|
print(f" Google image shape: {sample['google_img'].shape}")
|
||||||
|
print(f" Yandex image shape: {sample['yandex_img'].shape}")
|
||||||
|
print(f" Same domain: {sample['same_domain']}")
|
||||||
|
print(f" Index: {sample['idx'].item()}")
|
||||||
|
|
||||||
|
# 2. Создание даталоадеров
|
||||||
|
print("\n2. СОЗДАНИЕ ДАТАЛОАДЕРОВ")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
train_loader, val_loader = create_data_loaders(
|
||||||
|
root_dir=config["data_dir"],
|
||||||
|
batch_size=config["batch_size"],
|
||||||
|
train_split=config["train_split"],
|
||||||
|
num_workers=config["num_workers"],
|
||||||
|
image_size=config["image_size"],
|
||||||
|
augment_train=True,
|
||||||
|
augment_val=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Train batches: {len(train_loader)}")
|
||||||
|
print(f"Val batches: {len(val_loader)}")
|
||||||
|
|
||||||
|
# 3. Создание модели
|
||||||
|
print("\n3. СОЗДАНИЕ МОДЕЛИ")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
model = SimilarityCNN(
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")
|
||||||
|
|
||||||
|
# 4. Тестирование на одном батче
|
||||||
|
print("\n4. ТЕСТИРОВАНИЕ НА ОДНОМ БАТЧЕ")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
# Получаем батч из train_loader
|
||||||
|
for batch in train_loader:
|
||||||
|
google_img = batch["google_img"].to(device)
|
||||||
|
yandex_img = batch["yandex_img"].to(device)
|
||||||
|
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
|
||||||
|
|
||||||
|
print(f"Batch size: {google_img.shape[0]}")
|
||||||
|
print(f"Image shape: {google_img.shape[1:]}")
|
||||||
|
print(f"Same domain labels: {same_domain.squeeze().tolist()}")
|
||||||
|
|
||||||
|
# Предсказание схожести
|
||||||
|
with torch.no_grad():
|
||||||
|
predictions = model.predict_similarity(google_img, yandex_img)
|
||||||
|
print(f"\nПредсказания схожести:")
|
||||||
|
for i in range(len(predictions)):
|
||||||
|
print(
|
||||||
|
f" Sample {i}: {predictions[i].item():.4f} (target: {same_domain[i].item():.1f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Расчет потерь
|
||||||
|
loss_fn = SimilarityLoss().to(device)
|
||||||
|
loss = loss_fn(predictions, same_domain)
|
||||||
|
print(f"\nПотери на батче: {loss.item():.4f}")
|
||||||
|
|
||||||
|
# Расчет метрик
|
||||||
|
metrics = loss_fn.compute_metrics(predictions, same_domain)
|
||||||
|
print("\nМетрики на батче:")
|
||||||
|
for key, value in metrics.items():
|
||||||
|
print(f" {key}: {value:.4f}")
|
||||||
|
|
||||||
|
break # Только первый батч
|
||||||
|
|
||||||
|
# 5. Обучение на одном эпохе (демонстрация)
|
||||||
|
print("\n5. ДЕМОНСТРАЦИЯ ОБУЧЕНИЯ НА ОДНОЙ ЭПОХЕ")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
total_loss = 0
|
||||||
|
total_samples = 0
|
||||||
|
|
||||||
|
for batch_idx, batch in enumerate(train_loader):
|
||||||
|
if batch_idx >= 3: # Ограничиваем 3 батчами для демонстрации
|
||||||
|
break
|
||||||
|
|
||||||
|
google_img = batch["google_img"].to(device)
|
||||||
|
yandex_img = batch["yandex_img"].to(device)
|
||||||
|
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
predictions = model(google_img, yandex_img)
|
||||||
|
loss = loss_fn(predictions, same_domain)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
batch_loss = loss.item() * google_img.size(0)
|
||||||
|
total_loss += batch_loss
|
||||||
|
total_samples += google_img.size(0)
|
||||||
|
|
||||||
|
print(f"Batch {batch_idx + 1}: loss = {loss.item():.4f}")
|
||||||
|
|
||||||
|
avg_loss = total_loss / total_samples
|
||||||
|
print(f"\nСредние потери за 3 батча: {avg_loss:.4f}")
|
||||||
|
|
||||||
|
# 6. Валидация
|
||||||
|
print("\n6. ВАЛИДАЦИЯ")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
val_loss = 0
|
||||||
|
val_samples = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_idx, batch in enumerate(val_loader):
|
||||||
|
if batch_idx >= 2: # Ограничиваем 2 батчами для демонстрации
|
||||||
|
break
|
||||||
|
|
||||||
|
google_img = batch["google_img"].to(device)
|
||||||
|
yandex_img = batch["yandex_img"].to(device)
|
||||||
|
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
|
||||||
|
|
||||||
|
predictions = model.predict_similarity(google_img, yandex_img)
|
||||||
|
loss = loss_fn(predictions, same_domain)
|
||||||
|
|
||||||
|
val_loss += loss.item() * google_img.size(0)
|
||||||
|
val_samples += google_img.size(0)
|
||||||
|
|
||||||
|
print(f"Val batch {batch_idx + 1}: loss = {loss.item():.4f}")
|
||||||
|
|
||||||
|
avg_val_loss = val_loss / val_samples
|
||||||
|
print(f"\nСредние потери на валидации: {avg_val_loss:.4f}")
|
||||||
|
|
||||||
|
# 7. Пример использования для отдельных изображений
|
||||||
|
print("\n7. ПРИМЕР ДЛЯ ОТДЕЛЬНЫХ ИЗОБРАЖЕНИЙ")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
# Берем два примера из датасета
|
||||||
|
sample1 = dataset[0]
|
||||||
|
sample2 = dataset[1]
|
||||||
|
|
||||||
|
# Подготавливаем тензоры
|
||||||
|
img1_1 = sample1["google_img"].unsqueeze(0).to(device)
|
||||||
|
img1_2 = sample1["yandex_img"].unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
img2_1 = sample2["google_img"].unsqueeze(0).to(device)
|
||||||
|
img2_2 = sample2["yandex_img"].unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
# Предсказания
|
||||||
|
with torch.no_grad():
|
||||||
|
# Сравнение пар из одного домена
|
||||||
|
sim_same1 = model.predict_similarity(img1_1, img1_2)
|
||||||
|
sim_same2 = model.predict_similarity(img2_1, img2_2)
|
||||||
|
|
||||||
|
# Сравнение пар из разных доменов
|
||||||
|
sim_diff1 = model.predict_similarity(img1_1, img2_2)
|
||||||
|
sim_diff2 = model.predict_similarity(img2_1, img1_2)
|
||||||
|
|
||||||
|
print("Сравнение пар изображений:")
|
||||||
|
print(f" Пара 1 (один домен): {sim_same1.item():.4f}")
|
||||||
|
print(f" Пара 2 (один домен): {sim_same2.item():.4f}")
|
||||||
|
print(f" Разные домены 1: {sim_diff1.item():.4f}")
|
||||||
|
print(f" Разные домены 2: {sim_diff2.item():.4f}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("ПРИМЕР ЗАВЕРШЕН")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
322
models/SiaN-similarity/model.py
Normal file
322
models/SiaN-similarity/model.py
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class SimilarityCNN(nn.Module):
|
||||||
|
"""
|
||||||
|
CNN model for similarity estimation between two images.
|
||||||
|
|
||||||
|
Takes two images as input and outputs a similarity score between 0 and 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_channels: int = 3,
|
||||||
|
hidden_channels: int = 64,
|
||||||
|
num_blocks: int = 4,
|
||||||
|
dropout_rate: float = 0.3,
|
||||||
|
use_batch_norm: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.input_channels = input_channels
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.num_blocks = num_blocks
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
self.use_batch_norm = use_batch_norm
|
||||||
|
|
||||||
|
self.encoder = self._build_encoder()
|
||||||
|
|
||||||
|
self.fusion_layers = self._build_fusion_layers()
|
||||||
|
|
||||||
|
self.regression_head = self._build_regression_head()
|
||||||
|
|
||||||
|
self._initialize_weights()
|
||||||
|
|
||||||
|
def _build_encoder(self) -> nn.Module:
|
||||||
|
layers = []
|
||||||
|
in_channels = self.input_channels
|
||||||
|
out_channels = self.hidden_channels
|
||||||
|
|
||||||
|
layers.append(
|
||||||
|
nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3)
|
||||||
|
)
|
||||||
|
if self.use_batch_norm:
|
||||||
|
layers.append(nn.BatchNorm2d(out_channels))
|
||||||
|
layers.append(nn.ReLU(inplace=True))
|
||||||
|
layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
||||||
|
|
||||||
|
for i in range(self.num_blocks):
|
||||||
|
block_in_channels = out_channels
|
||||||
|
block_out_channels = out_channels * 2 if i < 2 else out_channels
|
||||||
|
|
||||||
|
layers.append(
|
||||||
|
ResidualBlock(
|
||||||
|
in_channels=block_in_channels,
|
||||||
|
out_channels=block_out_channels,
|
||||||
|
stride=1 if i == 0 else 2,
|
||||||
|
dropout_rate=self.dropout_rate,
|
||||||
|
use_batch_norm=self.use_batch_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if i < 2:
|
||||||
|
out_channels = block_out_channels
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def _build_fusion_layers(self) -> nn.Module:
|
||||||
|
fused_channels = self.hidden_channels * 8
|
||||||
|
|
||||||
|
layers = [
|
||||||
|
nn.Conv2d(
|
||||||
|
fused_channels, self.hidden_channels * 4, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
nn.BatchNorm2d(self.hidden_channels * 4)
|
||||||
|
if self.use_batch_norm
|
||||||
|
else nn.Identity(),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout2d(self.dropout_rate),
|
||||||
|
nn.Conv2d(
|
||||||
|
self.hidden_channels * 4,
|
||||||
|
self.hidden_channels * 2,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
),
|
||||||
|
nn.BatchNorm2d(self.hidden_channels * 2)
|
||||||
|
if self.use_batch_norm
|
||||||
|
else nn.Identity(),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout2d(self.dropout_rate),
|
||||||
|
nn.AdaptiveAvgPool2d((1, 1)),
|
||||||
|
]
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def _build_regression_head(self) -> nn.Module:
|
||||||
|
input_features = self.hidden_channels * 2
|
||||||
|
|
||||||
|
layers = [
|
||||||
|
nn.Flatten(),
|
||||||
|
nn.Linear(input_features, 512),
|
||||||
|
nn.BatchNorm1d(512) if self.use_batch_norm else nn.Identity(),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(self.dropout_rate),
|
||||||
|
nn.Linear(512, 256),
|
||||||
|
nn.BatchNorm1d(256) if self.use_batch_norm else nn.Identity(),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(self.dropout_rate),
|
||||||
|
nn.Linear(256, 128),
|
||||||
|
nn.BatchNorm1d(128) if self.use_batch_norm else nn.Identity(),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(self.dropout_rate),
|
||||||
|
nn.Linear(128, 1),
|
||||||
|
nn.Sigmoid(),
|
||||||
|
]
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def _initialize_weights(self):
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
nn.init.normal_(m.weight, 0, 0.01)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
img1: torch.Tensor,
|
||||||
|
img2: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
features1 = self.encoder(img1)
|
||||||
|
features2 = self.encoder(img2)
|
||||||
|
|
||||||
|
combined_features = torch.cat([features1, features2], dim=1)
|
||||||
|
|
||||||
|
fused_features = self.fusion_layers(combined_features)
|
||||||
|
|
||||||
|
similarity = self.regression_head(fused_features)
|
||||||
|
|
||||||
|
return similarity
|
||||||
|
|
||||||
|
def predict_similarity(
|
||||||
|
self,
|
||||||
|
img1: torch.Tensor,
|
||||||
|
img2: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
original_training = self.training
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
similarity = self.forward(img1, img2)
|
||||||
|
if original_training:
|
||||||
|
self.train()
|
||||||
|
return similarity
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
stride: int = 1,
|
||||||
|
dropout_rate: float = 0.3,
|
||||||
|
use_batch_norm: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=stride,
|
||||||
|
padding=1,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.bn1 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
|
||||||
|
self.relu1 = nn.ReLU(inplace=True)
|
||||||
|
self.dropout1 = (
|
||||||
|
nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
|
||||||
|
)
|
||||||
|
self.bn2 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
|
||||||
|
self.relu2 = nn.ReLU(inplace=True)
|
||||||
|
self.dropout2 = (
|
||||||
|
nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.shortcut = nn.Sequential()
|
||||||
|
if stride != 1 or in_channels != out_channels:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=stride, bias=False
|
||||||
|
),
|
||||||
|
nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
identity = self.shortcut(x)
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu1(out)
|
||||||
|
out = self.dropout1(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
|
||||||
|
out += identity
|
||||||
|
out = self.relu2(out)
|
||||||
|
out = self.dropout2(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class SimilarityLoss(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.criterion = nn.BCELoss()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pred_similarity: torch.Tensor,
|
||||||
|
target_same: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.criterion(pred_similarity, target_same)
|
||||||
|
|
||||||
|
def compute_metrics(
|
||||||
|
self,
|
||||||
|
pred_similarity: torch.Tensor,
|
||||||
|
target_same: torch.Tensor,
|
||||||
|
threshold: float = 0.5,
|
||||||
|
) -> dict:
|
||||||
|
with torch.no_grad():
|
||||||
|
pred_binary = (pred_similarity > threshold).float()
|
||||||
|
target_binary = (target_same > 0.5).float()
|
||||||
|
|
||||||
|
correct = (pred_binary == target_binary).float()
|
||||||
|
accuracy = correct.mean().item()
|
||||||
|
|
||||||
|
tp = ((pred_binary == 1) & (target_binary == 1)).float().sum().item()
|
||||||
|
fp = ((pred_binary == 1) & (target_binary == 0)).float().sum().item()
|
||||||
|
fn = ((pred_binary == 0) & (target_binary == 1)).float().sum().item()
|
||||||
|
tn = ((pred_binary == 0) & (target_binary == 0)).float().sum().item()
|
||||||
|
|
||||||
|
precision = tp / (tp + fp + 1e-8)
|
||||||
|
recall = tp / (tp + fn + 1e-8)
|
||||||
|
f1 = 2 * precision * recall / (precision + recall + 1e-8)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"accuracy": accuracy,
|
||||||
|
"precision": precision,
|
||||||
|
"recall": recall,
|
||||||
|
"f1": f1,
|
||||||
|
"mean_similarity": pred_similarity.mean().item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_similarity_model(
|
||||||
|
model_type: str = "cnn",
|
||||||
|
input_size: Tuple[int, int] = (256, 256),
|
||||||
|
**kwargs,
|
||||||
|
) -> nn.Module:
|
||||||
|
if model_type == "cnn":
|
||||||
|
return SimilarityCNN(**kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model type: {model_type}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
model = SimilarityCNN(
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size = 4
|
||||||
|
height, width = 256, 256
|
||||||
|
|
||||||
|
img1 = torch.randn(batch_size, 3, height, width).to(device)
|
||||||
|
img2 = torch.randn(batch_size, 3, height, width).to(device)
|
||||||
|
|
||||||
|
print("\nTesting forward pass...")
|
||||||
|
output = model(img1, img2)
|
||||||
|
print(f"Output shape: {output.shape}")
|
||||||
|
print(f"Sample output: {output[0].item():.4f}")
|
||||||
|
|
||||||
|
print("\nTesting prediction...")
|
||||||
|
pred = model.predict_similarity(img1, img2)
|
||||||
|
print(f"Prediction shape: {pred.shape}")
|
||||||
|
|
||||||
|
print("\nTesting loss function...")
|
||||||
|
target = torch.rand(batch_size, 1).to(device)
|
||||||
|
loss_fn = SimilarityLoss().to(device)
|
||||||
|
loss = loss_fn(output, target)
|
||||||
|
print(f"Loss value: {loss.item():.6f}")
|
||||||
|
|
||||||
|
print("\nTesting metrics...")
|
||||||
|
metrics = loss_fn.compute_metrics(output, target)
|
||||||
|
for key, value in metrics.items():
|
||||||
|
print(f"{key}: {value:.6f}")
|
||||||
|
|
||||||
|
print("\nAll tests completed successfully!")
|
||||||
146
models/SiaN-similarity/predict.py
Normal file
146
models/SiaN-similarity/predict.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""
|
||||||
|
Script for predicting similarity between two images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from model import SimilarityCNN
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(image_path: str, image_size: tuple = (256, 256)) -> torch.Tensor:
|
||||||
|
"""Load and preprocess image."""
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Resize(image_size),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
return transform(image).unsqueeze(0) # Add batch dimension
|
||||||
|
|
||||||
|
|
||||||
|
def predict_similarity(
|
||||||
|
model: SimilarityCNN,
|
||||||
|
image1_path: str,
|
||||||
|
image2_path: str,
|
||||||
|
device: torch.device,
|
||||||
|
image_size: tuple = (256, 256),
|
||||||
|
) -> float:
|
||||||
|
"""Predict similarity between two images."""
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
img1 = load_image(image1_path, image_size).to(device)
|
||||||
|
img2 = load_image(image2_path, image_size).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
similarity = model(img1, img2)
|
||||||
|
|
||||||
|
return similarity.item()
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(
|
||||||
|
checkpoint_path: str,
|
||||||
|
device: torch.device,
|
||||||
|
**model_kwargs,
|
||||||
|
) -> SimilarityCNN:
|
||||||
|
"""Load model from checkpoint."""
|
||||||
|
model = SimilarityCNN(**model_kwargs).to(device)
|
||||||
|
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||||
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Predict similarity between two images"
|
||||||
|
)
|
||||||
|
parser.add_argument("--image1", type=str, required=True, help="Path to first image")
|
||||||
|
parser.add_argument(
|
||||||
|
"--image2", type=str, required=True, help="Path to second image"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint",
|
||||||
|
type=str,
|
||||||
|
default="runs/similarity/checkpoints/best_model.pt",
|
||||||
|
help="Path to model checkpoint",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
help="Device to use for inference",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image_size",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help="Image size for model input",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
device = torch.device(args.device)
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
if not os.path.exists(args.image1):
|
||||||
|
print(f"Error: Image not found: {args.image1}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not os.path.exists(args.image2):
|
||||||
|
print(f"Error: Image not found: {args.image2}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not os.path.exists(args.checkpoint):
|
||||||
|
print(f"Warning: Checkpoint not found: {args.checkpoint}")
|
||||||
|
print("Using randomly initialized model for demonstration")
|
||||||
|
model = SimilarityCNN(
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
).to(device)
|
||||||
|
else:
|
||||||
|
print(f"Loading model from: {args.checkpoint}")
|
||||||
|
model = load_model(
|
||||||
|
checkpoint_path=args.checkpoint,
|
||||||
|
device=device,
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters"
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity = predict_similarity(
|
||||||
|
model=model,
|
||||||
|
image1_path=args.image1,
|
||||||
|
image2_path=args.image2,
|
||||||
|
device=device,
|
||||||
|
image_size=(args.image_size, args.image_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nSimilarity between images:")
|
||||||
|
print(f" Image 1: {args.image1}")
|
||||||
|
print(f" Image 2: {args.image2}")
|
||||||
|
print(f" Similarity score: {similarity:.4f}")
|
||||||
|
print(f" Interpretation: {'Similar' if similarity > 0.5 else 'Different'}")
|
||||||
|
|
||||||
|
return similarity
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
275
models/SiaN-similarity/train.py
Normal file
275
models/SiaN-similarity/train.py
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
"""
|
||||||
|
Training script for image similarity estimation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from dataloader import create_data_loaders
|
||||||
|
from model import SimilarityCNN, SimilarityLoss, create_similarity_model
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class SimilarityTrainer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
train_loader: DataLoader,
|
||||||
|
val_loader: DataLoader,
|
||||||
|
device: torch.device,
|
||||||
|
config: dict,
|
||||||
|
):
|
||||||
|
self.model = model.to(device)
|
||||||
|
self.train_loader = train_loader
|
||||||
|
self.val_loader = val_loader
|
||||||
|
self.device = device
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.criterion = SimilarityLoss()
|
||||||
|
self.optimizer = optim.Adam(
|
||||||
|
model.parameters(),
|
||||||
|
lr=config.get("learning_rate", 2e-4),
|
||||||
|
betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.writer = None
|
||||||
|
self.best_val_loss = float("inf")
|
||||||
|
self.epochs_without_improvement = 0
|
||||||
|
|
||||||
|
def train_epoch(self, epoch: int) -> dict:
|
||||||
|
self.model.train()
|
||||||
|
total_loss = 0
|
||||||
|
total_samples = 0
|
||||||
|
|
||||||
|
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
|
||||||
|
for batch_idx, batch in enumerate(pbar):
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
target = batch["same_domain"].float().to(self.device).unsqueeze(1)
|
||||||
|
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
output = self.model(google_img, yandex_img)
|
||||||
|
loss = self.criterion(output, target)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
total_loss += loss.item() * google_img.size(0)
|
||||||
|
total_samples += google_img.size(0)
|
||||||
|
|
||||||
|
if batch_idx % self.config.get("log_interval", 10) == 0:
|
||||||
|
metrics = self.criterion.compute_metrics(output, target)
|
||||||
|
pbar.set_postfix(
|
||||||
|
{
|
||||||
|
"loss": loss.item(),
|
||||||
|
"acc": metrics["accuracy"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.writer:
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/loss",
|
||||||
|
loss.item(),
|
||||||
|
epoch * len(self.train_loader) + batch_idx,
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/accuracy",
|
||||||
|
metrics["accuracy"],
|
||||||
|
epoch * len(self.train_loader) + batch_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
avg_loss = total_loss / total_samples
|
||||||
|
return {"loss": avg_loss}
|
||||||
|
|
||||||
|
def validate(self) -> dict:
|
||||||
|
self.model.eval()
|
||||||
|
total_loss = 0
|
||||||
|
total_samples = 0
|
||||||
|
all_metrics = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in tqdm(self.val_loader, desc="Validation"):
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
target = batch["same_domain"].float().to(self.device).unsqueeze(1)
|
||||||
|
|
||||||
|
output = self.model(google_img, yandex_img)
|
||||||
|
loss = self.criterion(output, target)
|
||||||
|
|
||||||
|
total_loss += loss.item() * google_img.size(0)
|
||||||
|
total_samples += google_img.size(0)
|
||||||
|
|
||||||
|
metrics = self.criterion.compute_metrics(output, target)
|
||||||
|
all_metrics.append(metrics)
|
||||||
|
|
||||||
|
avg_loss = total_loss / total_samples
|
||||||
|
|
||||||
|
avg_metrics = {}
|
||||||
|
for key in all_metrics[0].keys():
|
||||||
|
avg_metrics[key] = sum(m[key] for m in all_metrics) / len(all_metrics)
|
||||||
|
|
||||||
|
return {"loss": avg_loss, **avg_metrics}
|
||||||
|
|
||||||
|
def train(self, num_epochs: int):
|
||||||
|
log_dir = self.config.get("output_dir", "runs/similarity")
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
self.writer = SummaryWriter(log_dir)
|
||||||
|
|
||||||
|
print(f"Starting training for {num_epochs} epochs")
|
||||||
|
print(f"Logging to: {log_dir}")
|
||||||
|
|
||||||
|
for epoch in range(1, num_epochs + 1):
|
||||||
|
print(f"\nEpoch {epoch}/{num_epochs}")
|
||||||
|
|
||||||
|
train_metrics = self.train_epoch(epoch)
|
||||||
|
val_metrics = self.validate()
|
||||||
|
|
||||||
|
print(f"Train Loss: {train_metrics['loss']:.4f}")
|
||||||
|
print(f"Val Loss: {val_metrics['loss']:.4f}")
|
||||||
|
print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
|
||||||
|
print(f"Val F1: {val_metrics['f1']:.4f}")
|
||||||
|
|
||||||
|
if self.writer:
|
||||||
|
self.writer.add_scalar("epoch/train_loss", train_metrics["loss"], epoch)
|
||||||
|
self.writer.add_scalar("epoch/val_loss", val_metrics["loss"], epoch)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"epoch/val_accuracy", val_metrics["accuracy"], epoch
|
||||||
|
)
|
||||||
|
|
||||||
|
if val_metrics["loss"] < self.best_val_loss:
|
||||||
|
self.best_val_loss = val_metrics["loss"]
|
||||||
|
self.epochs_without_improvement = 0
|
||||||
|
self.save_checkpoint(epoch, val_metrics["loss"], is_best=True)
|
||||||
|
print(f"New best model saved with val loss: {val_metrics['loss']:.4f}")
|
||||||
|
else:
|
||||||
|
self.epochs_without_improvement += 1
|
||||||
|
self.save_checkpoint(epoch, val_metrics["loss"], is_best=False)
|
||||||
|
|
||||||
|
patience = self.config.get("early_stopping_patience", 20)
|
||||||
|
if self.epochs_without_improvement >= patience:
|
||||||
|
print(
|
||||||
|
f"Early stopping triggered after {patience} epochs without improvement"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
self.writer.close()
|
||||||
|
|
||||||
|
def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
|
||||||
|
checkpoint_dir = os.path.join(
|
||||||
|
self.config.get("output_dir", "runs/similarity"), "checkpoints"
|
||||||
|
)
|
||||||
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
|
||||||
|
checkpoint = {
|
||||||
|
"epoch": epoch,
|
||||||
|
"model_state_dict": self.model.state_dict(),
|
||||||
|
"optimizer_state_dict": self.optimizer.state_dict(),
|
||||||
|
"val_loss": val_loss,
|
||||||
|
"config": self.config,
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
|
||||||
|
torch.save(checkpoint, checkpoint_path)
|
||||||
|
|
||||||
|
if is_best:
|
||||||
|
best_path = os.path.join(checkpoint_dir, "best_model.pt")
|
||||||
|
torch.save(checkpoint, best_path)
|
||||||
|
|
||||||
|
def load_checkpoint(self, checkpoint_path: str):
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||||
|
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||||
|
return checkpoint["epoch"], checkpoint["val_loss"]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Train similarity estimation model")
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_dir",
|
||||||
|
type=str,
|
||||||
|
default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||||
|
)
|
||||||
|
parser.add_argument("--batch_size", type=int, default=32)
|
||||||
|
parser.add_argument("--epochs", type=int, default=100)
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=2e-4)
|
||||||
|
parser.add_argument("--image_size", type=int, default=256)
|
||||||
|
parser.add_argument("--train_split", type=float, default=0.8)
|
||||||
|
parser.add_argument("--output_dir", type=str, default="runs/similarity")
|
||||||
|
parser.add_argument("--num_workers", type=int, default=0)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"data_dir": args.data_dir,
|
||||||
|
"batch_size": args.batch_size,
|
||||||
|
"epochs": args.epochs,
|
||||||
|
"learning_rate": args.learning_rate,
|
||||||
|
"image_size": (args.image_size, args.image_size),
|
||||||
|
"train_split": args.train_split,
|
||||||
|
"output_dir": args.output_dir,
|
||||||
|
"num_workers": args.num_workers,
|
||||||
|
"log_interval": 10,
|
||||||
|
"save_interval": 5,
|
||||||
|
"early_stopping_patience": 20,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
}
|
||||||
|
|
||||||
|
device = torch.device(args.device)
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
print("Creating data loaders...")
|
||||||
|
train_loader, val_loader = create_data_loaders(
|
||||||
|
root_dir=config["data_dir"],
|
||||||
|
batch_size=config["batch_size"],
|
||||||
|
train_split=config["train_split"],
|
||||||
|
num_workers=config["num_workers"],
|
||||||
|
image_size=config["image_size"],
|
||||||
|
augment_train=True,
|
||||||
|
augment_val=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Train batches: {len(train_loader)}")
|
||||||
|
print(f"Val batches: {len(val_loader)}")
|
||||||
|
|
||||||
|
print("Creating model...")
|
||||||
|
model = create_similarity_model(
|
||||||
|
model_type="cnn",
|
||||||
|
input_size=config["image_size"],
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||||||
|
|
||||||
|
trainer = SimilarityTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Starting training...")
|
||||||
|
trainer.train(config["epochs"])
|
||||||
|
|
||||||
|
print("Training completed!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user