feat: add similarity model

This commit is contained in:
2026-03-03 21:42:23 +03:00
parent 1de150b386
commit 43cd4222bc
7 changed files with 1801 additions and 0 deletions

View File

@@ -0,0 +1,131 @@
# SiaN-Similarity: Модель для оценки схожести изображений
Модель для оценки схожести между двумя изображениями 256x256. Возвращает значение от 0 до 1, где 1 означает полную схожесть, 0 - полное различие.
## Архитектура модели
Модель основана на CNN с residual блоками:
- Общий энкодер для обоих изображений
- Residual blocks с batch normalization
- Слой слияния признаков
- Регрессионная голова с сигмоидой на выходе
## Использование
### Установка зависимостей
```bash
pip install torch torchvision pillow
```
### Быстрый старт
```python
import torch
from model import SimilarityCNN
# Создание модели
model = SimilarityCNN(
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
)
# Предсказание схожести
img1 = torch.randn(1, 3, 256, 256) # Изображение 1
img2 = torch.randn(1, 3, 256, 256) # Изображение 2
similarity = model.predict_similarity(img1, img2)
print(f"Схожесть: {similarity.item():.4f}")
```
### Обучение модели
```bash
python train_similarity.py \
--data_dir "путь/к/данным" \
--batch_size 32 \
--epochs 100 \
--learning_rate 2e-4 \
--output_dir "runs/similarity"
```
### Предсказание на новых изображениях
```bash
python predict.py \
--image1 "путь/к/изображению1.png" \
--image2 "путь/к/изображению2.png" \
--checkpoint "runs/similarity/checkpoints/best_model.pt"
```
## Структура проекта
```
SiaN-similarity/
├── model.py # Основная модель
├── dataloader.py # Даталоадер для обучения
├── train_similarity.py # Скрипт для обучения
├── predict.py # Скрипт для предсказания
├── train.py # Оригинальный тренировочный скрипт
└── README.md # Этот файл
```
## Конфигурация модели
Параметры по умолчанию:
- `input_channels`: 3 (RGB)
- `hidden_channels`: 64
- `num_blocks`: 4
- `dropout_rate`: 0.3
- `use_batch_norm`: True
- `image_size`: (256, 256)
## Формат данных
Модель ожидает изображения размером 256x256 пикселей в формате RGB.
Для обучения используется датасет с парами изображений и метками схожести.
## Примеры использования
### 1. Создание и тестирование модели
```python
from model import create_similarity_model
model = create_similarity_model(
model_type="cnn",
input_size=(256, 256),
hidden_channels=32,
num_blocks=3,
)
```
### 2. Использование функции потерь
```python
from model import SimilarityLoss
loss_fn = SimilarityLoss()
pred = torch.tensor([[0.8], [0.2]])
target = torch.tensor([[1.0], [0.0]])
loss = loss_fn(pred, target)
```
### 3. Расчет метрик
```python
metrics = loss_fn.compute_metrics(pred, target)
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"F1-score: {metrics['f1']:.4f}")
```
## Требования
- Python 3.8+
- PyTorch 1.9+
- torchvision
- Pillow
- numpy
## Лицензия
MIT

View File

@@ -0,0 +1,519 @@
config = {
# Параметры оптимизатора
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
# Параметры обучения
"batch_size": 4,
"epochs": 100,
# Параметры GAN
"gan_mode": "vanilla", # "vanilla", "lsgan", или "wgangp"
"lambda_L1": 100.0, # Вес L1 потерь
# Регуляризация
"grad_clip": 1.0,
# Ранняя остановка
"early_stopping_patience": 20,
# Выходные данные
"output_dir": "runs/gan_training",
# Логирование
"log_interval": 10, # Логировать каждые N батчей
"save_interval": 5, # Сохранять чекпоинт каждые N эпох
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
"batch_size": 32,
"image_size": [256, 256],
"train_split": 0.8,
"num_workers": 0,
}
import os
import random
from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
class YaGoDataset(Dataset):
"""
Dataset for homography estimation between Yandex and Google map image pairs.
This dataset loads pairs of images (Yandex and Google maps) and provides
homography matrices for data augmentation and training.
"""
def __init__(
self,
root_dir: str,
transform=None,
augment: bool = True,
max_samples: Optional[int] = None,
image_size: Tuple[int, int] = (700, 700),
cache_homographies: bool = True,
device: torch.device = None,
):
"""
Initialize the YaGoDataset.
Args:
root_dir: Directory containing image pairs (format: {idx:04d}_google.png, {idx:04d}_yandex.png)
transform: Optional torchvision transforms to apply
augment: Whether to apply homography-based data augmentation
max_samples: Maximum number of samples to load (None for all)
image_size: Target size for images (height, width)
cache_homographies: Whether to cache generated homography matrices to disk
"""
self.root_dir = root_dir
self.transform = transform
self.augment = augment
self.image_size = image_size
self.cache_homographies = cache_homographies
self.device = device
# Find all image pairs
self.image_pairs = self._discover_image_pairs()
if max_samples is not None:
self.image_pairs = self.image_pairs[:max_samples]
print(f"Found {len(self.image_pairs)} image pairs in {root_dir}")
def _discover_image_pairs(self) -> List[Dict[str, Any]]:
"""Discover all Google-Yandex image pairs in the dataset directory."""
image_pairs = []
# Get all Google images
google_files = [
f for f in os.listdir(self.root_dir) if f.endswith("_google.png")
]
for google_file in sorted(google_files):
# Extract index from filename
idx_str = google_file.split("_")[0]
try:
idx = int(idx_str)
except ValueError:
continue
# Check if corresponding Yandex image exists
yandex_file = f"{idx:04d}_yandex.png"
yandex_path = os.path.join(self.root_dir, yandex_file)
if os.path.exists(yandex_path):
image_pairs.append(
{
"idx": idx,
"google_path": os.path.join(self.root_dir, google_file),
"yandex_path": yandex_path,
}
)
return image_pairs
def __len__(self) -> int:
"""Return the number of image pairs in the dataset."""
return len(self.image_pairs)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""
Get a sample from the dataset.
Returns a dictionary with:
- 'google_img': Google map image tensor
- 'yandex_img': Yandex map image tensor
- 'homography': Ground truth homography matrix (3x3)
- 'idx': Sample index
"""
pair_info = self.image_pairs[idx]
google_path = pair_info["google_path"]
yandex_path = pair_info["yandex_path"]
same_domain = True
if np.random.rand() > 0.5:
random_idx = np.random.randint(0, len(self))
google_path = self.image_pairs[random_idx]["google_path"]
same_domain = random_idx == idx
# Load images
yandex_img = Image.open(yandex_path).convert("RGB")
google_img = Image.open(google_path).convert("RGB")
# Resize images to target size
google_img = google_img.resize(
(self.image_size[1], self.image_size[0]), Image.BILINEAR
)
yandex_img = yandex_img.resize(
(self.image_size[1], self.image_size[0]), Image.BILINEAR
)
# Get or generate homography matrix
matrices: Tuple[np.ndarray, np.ndarray, np.ndarray] = (
self._get_homography_matrix(pair_info["idx"])
)
# Apply data augmentation if enabled
if self.augment:
google_img, yandex_img, homography_matrix = self._apply_augmentation(
google_img, yandex_img, matrices
)
# Convert images to tensors
if self.transform:
google_img = self.transform(google_img)
yandex_img = self.transform(yandex_img)
else:
# Default conversion to tensor
google_img = (
torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0
)
yandex_img = (
torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0
)
# Convert homography to tensor
if self.augment:
homography_tensor = torch.from_numpy(homography_matrix).float()
else:
homography_tensor = torch.from_numpy(np.eye(3))
return {
"google_img": google_img,
"yandex_img": yandex_img,
"homography": homography_tensor,
"same_domain": same_domain,
"idx": torch.tensor(pair_info["idx"], dtype=torch.long),
}
def _get_homography_matrix(
self, idx: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Get homography matrices for a given index.
If cached homography exists, load it. Otherwise generate a new one.
"""
# Generate new homography matrix
homography_matrix_1 = self.generate_random_homography()
homography_matrix_2 = self.generate_random_homography()
homography_matrix_r = np.linalg.inv(homography_matrix_1) @ homography_matrix_2
result = (homography_matrix_1, homography_matrix_2, homography_matrix_r)
return result
def generate_random_homography(self) -> np.ndarray:
"""
Generate a random homography matrix for data augmentation.
Returns:
np.ndarray: 3x3 homography matrix.
"""
# Generate random affine transformation parameters
scale = np.random.uniform(0.8, 1.2) # scaling factor
tx = np.random.uniform(-0.50, 0.50) # translation in x
ty = np.random.uniform(-0.50, 0.50) # translation in y
# rotation
angle_x = np.random.uniform(np.radians(-10), np.radians(10))
angle_y = np.random.uniform(np.radians(-10), np.radians(10))
angle_z = np.random.uniform(np.radians(-10), np.radians(10))
cy, sy = np.cos(angle_z), np.sin(angle_z)
cp, sp = np.cos(angle_y), np.sin(angle_y)
cr, sr = np.cos(angle_x), np.sin(angle_x)
Rz = np.array(
[
[cy, -sy, 0],
[sy, cy, 0],
[0, 0, 1],
]
)
Ry = np.array(
[
[cp, 0, sp],
[0, 1, 0],
[-sp, 0, cp],
]
)
Rx = np.array(
[
[1, 0, 0],
[0, cr, -sr],
[0, sr, cr],
]
)
# Create affine transformation matrix
T = np.array(
[
[1, 0, tx],
[0, 1, ty],
[0, 0, scale],
]
)
K = self.get_camera_matrix()
return K @ Rx @ Ry @ Rz @ T @ np.linalg.inv(K)
def get_camera_matrix(self) -> np.ndarray:
w, h = config["image_size"]
K = np.array(
[
[w / 2, 0, w / 2],
[0, h / 2, h / 2],
[0, 0, 1],
]
)
return K
def _apply_augmentation(
self,
google_img: Image.Image,
yandex_img: Image.Image,
matrices: Tuple[np.ndarray, np.ndarray, np.ndarray],
) -> Tuple[Image.Image, Image.Image, np.ndarray]:
"""
Apply homography-based data augmentation to image pair.
Args:
google_img: Google map image
yandex_img: Yandex map image
matrices: homography matrices
Returns:
Tuple of (augmented_google_img, augmented_yandex_img, augmented_homography)
"""
# Combine with base homography
combined_homography = matrices[2]
# Apply augmentation to both images
# google_aug = self._apply_homography_to_image(google_img, aug_homography)
yandex_aug = self._apply_homography_to_image(yandex_img, matrices[0])
google_aug = self._apply_homography_to_image(google_img, matrices[1])
return google_aug, yandex_aug, combined_homography
def _apply_homography_to_image(
self, img: Image.Image, homography: np.ndarray
) -> Image.Image:
"""
Apply homography transformation to a single image.
Args:
img: PIL Image to transform
homography: 3x3 homography matrix
Returns:
Transformed PIL Image
"""
# Convert to numpy array
img_np = np.array(img)
# Get image dimensions
h, w = img_np.shape[:2]
# Apply homography transformation
transformed = cv2.warpPerspective(
img_np,
homography,
(w, h),
flags=cv2.INTER_LINEAR,
# borderMode=cv2.BORDER_REFLECT,
)
# Convert back to PIL Image
return Image.fromarray(transformed)
def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]:
"""
Get a sample without data augmentation.
Useful for visualization and evaluation.
"""
pair_info = self.image_pairs[idx]
# Load images
google_img = Image.open(pair_info["google_path"]).convert("RGB")
yandex_img = Image.open(pair_info["yandex_path"]).convert("RGB")
# Resize
google_img = google_img.resize(
(self.image_size[1], self.image_size[0]), Image.BILINEAR
)
yandex_img = yandex_img.resize(
(self.image_size[1], self.image_size[0]), Image.BILINEAR
)
# Get homography matrix
homography_matrix = self._get_homography_matrix(pair_info["idx"])
return {
"google_img": google_img,
"yandex_img": yandex_img,
"homography": homography_matrix,
"idx": pair_info["idx"],
"google_path": pair_info["google_path"],
"yandex_path": pair_info["yandex_path"],
}
def create_data_loaders(
root_dir: str,
batch_size: int = 32,
train_split: float = 0.8,
num_workers: int = 4,
image_size: Tuple[int, int] = (256, 256),
augment_train: bool = True,
augment_val: bool = False,
device: torch.device = None,
) -> Tuple[DataLoader, DataLoader]:
"""
Create train and validation data loaders for homography estimation.
Args:
root_dir: Directory containing image pairs
batch_size: Batch size for data loaders
train_split: Fraction of data to use for training
num_workers: Number of worker processes for data loading
image_size: Target image size (height, width)
augment_train: Whether to augment training data
augment_val: Whether to augment validation data
device: Target device for tensors (optional)
Returns:
Tuple of (train_loader, val_loader)
"""
from torchvision import transforms
# Define transforms
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
# Create full dataset
full_dataset = YaGoDataset(
root_dir=root_dir,
transform=transform,
augment=False, # We'll handle augmentation separately
image_size=image_size,
cache_homographies=True,
device=device,
)
# Split dataset
dataset_size = len(full_dataset)
train_size = int(train_split * dataset_size)
val_size = dataset_size - train_size
# Create indices for splitting
indices = list(range(dataset_size))
random.shuffle(indices)
train_indices = indices[:train_size]
val_indices = indices[train_size:]
# Create subset samplers
from torch.utils.data import Subset
train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(full_dataset, val_indices)
# Apply augmentation by overriding __getitem__ for train dataset
if augment_train:
class AugmentedSubset(Subset):
def __init__(self, dataset, indices, device=None):
super().__init__(dataset, indices)
self.device = device
def __getitem__(self, idx):
sample = self.dataset[self.indices[idx]]
# Apply augmentation
google_img = sample["google_img"]
yandex_img = sample["yandex_img"]
homography = sample["homography"]
if self.device is not None:
google_img = google_img.to(self.device)
yandex_img = yandex_img.to(self.device)
homography = homography.to(self.device)
# Generate augmentation homography
aug_homography = torch.from_numpy(
full_dataset.generate_random_homography()
).float()
if self.device is not None:
aug_homography = aug_homography.to(self.device)
# Combine homographies
combined_homography = aug_homography @ homography
# Apply augmentation (simplified - in practice would warp images)
# For now, we just return the combined homography
return {
"google_img": google_img,
"yandex_img": yandex_img,
"homography": combined_homography,
"idx": sample["idx"],
}
train_dataset = AugmentedSubset(full_dataset, train_indices, device=device)
# Create data loaders
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True,
)
return train_loader, val_loader
# Example usage
dataset = YaGoDataset(
root_dir=config["data_dir"],
augment=False,
image_size=(256, 256),
)
print(f"Dataset size: {len(dataset)}")
# Get a sample
sample = dataset[0]
print(f"Sample keys: {list(sample.keys())}")
print(f"Google image shape: {sample['google_img'].shape}")
print(f"Yandex image shape: {sample['yandex_img'].shape}")
print(f"Homography shape: {sample['homography'].shape}")
# Create data loaders
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=16,
train_split=0.8,
)
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

View File

@@ -0,0 +1,192 @@
"""
Демонстрационный скрипт для модели оценки схожести изображений.
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
from model import SimilarityCNN, SimilarityLoss
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
def create_test_images():
"""Создание тестовых изображений для демонстрации."""
images = []
# Изображение 1: Красный квадрат
img1 = Image.new("RGB", (256, 256), color="white")
draw = ImageDraw.Draw(img1)
draw.rectangle([50, 50, 200, 200], fill="red", outline="black", width=2)
images.append(("Красный квадрат", img1))
# Изображение 2: Тот же красный квадрат (похожее)
img2 = Image.new("RGB", (256, 256), color="white")
draw = ImageDraw.Draw(img2)
draw.rectangle([55, 55, 205, 205], fill="red", outline="black", width=2)
images.append(("Похожий красный квадрат", img2))
# Изображение 3: Синий круг (разное)
img3 = Image.new("RGB", (256, 256), color="white")
draw = ImageDraw.Draw(img3)
draw.ellipse([50, 50, 200, 200], fill="blue", outline="black", width=2)
images.append(("Синий круг", img3))
# Изображение 4: Зеленый треугольник (разное)
img4 = Image.new("RGB", (256, 256), color="white")
draw = ImageDraw.Draw(img4)
draw.polygon(
[(128, 50), (50, 200), (200, 200)], fill="green", outline="black", width=2
)
images.append(("Зеленый треугольник", img4))
return images
def preprocess_image(image):
"""Преобразование PIL Image в тензор."""
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
return transform(image).unsqueeze(0) # Добавляем batch dimension
def display_results(images, similarities):
"""Отображение результатов сравнения."""
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()
for idx, (title, img) in enumerate(images):
ax = axes[idx]
ax.imshow(img)
ax.set_title(title, fontsize=12, fontweight="bold")
ax.axis("off")
plt.suptitle("Тестовые изображения", fontsize=16, fontweight="bold")
plt.tight_layout()
plt.show()
# Вывод результатов сравнения
print("\n" + "=" * 60)
print("РЕЗУЛЬТАТЫ СРАВНЕНИЯ ИЗОБРАЖЕНИЙ")
print("=" * 60)
comparisons = [
("Красный квадрат", "Похожий красный квадрат"),
("Красный квадрат", "Синий круг"),
("Красный квадрат", "Зеленый треугольник"),
("Похожий красный квадрат", "Синий круг"),
]
for i, (name1, name2) in enumerate(comparisons):
idx1 = [idx for idx, (name, _) in enumerate(images) if name == name1][0]
idx2 = [idx for idx, (name, _) in enumerate(images) if name == name2][0]
sim = similarities[idx1, idx2]
interpretation = "ПОХОЖИ" if sim > 0.5 else "РАЗНЫЕ"
print(f"\n{name1} vs {name2}:")
print(f" Схожесть: {sim:.4f}")
print(f" Интерпретация: {interpretation}")
print(f" Уверенность: {'Высокая' if sim > 0.7 or sim < 0.3 else 'Средняя'}")
def test_loss_function():
"""Тестирование функции потерь."""
print("\n" + "=" * 60)
print("ТЕСТИРОВАНИЕ ФУНКЦИИ ПОТЕРЬ")
print("=" * 60)
loss_fn = SimilarityLoss()
# Тестовые данные
predictions = torch.tensor([[0.9], [0.1], [0.7], [0.3]])
targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
# Расчет потерь
loss = loss_fn(predictions, targets)
print(f"\nПотери: {loss.item():.4f}")
# Расчет метрик
metrics = loss_fn.compute_metrics(predictions, targets)
print("\nМетрики:")
for key, value in metrics.items():
print(f" {key}: {value:.4f}")
def main():
"""Основная функция демонстрации."""
print("ДЕМОНСТРАЦИЯ МОДЕЛИ ОЦЕНКИ СХОЖЕСТИ ИЗОБРАЖЕНИЙ")
print("=" * 60)
# Создание модели
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nУстройство: {device}")
model = SimilarityCNN(
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
).to(device)
print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")
# Создание тестовых изображений
print("\nСоздание тестовых изображений...")
test_images = create_test_images()
# Преобразование изображений в тензоры
tensors = []
for name, img in test_images:
tensor = preprocess_image(img).to(device)
tensors.append(tensor)
# Расчет схожести между всеми парами изображений
print("\nРасчет схожести между изображениями...")
n_images = len(test_images)
similarity_matrix = np.zeros((n_images, n_images))
model.eval()
with torch.no_grad():
for i in range(n_images):
for j in range(n_images):
if i <= j: # Рассчитываем только верхний треугольник
sim = model.predict_similarity(tensors[i], tensors[j])
similarity_matrix[i, j] = sim.item()
similarity_matrix[j, i] = sim.item() # Симметричная матрица
# Отображение результатов
display_results(test_images, similarity_matrix)
# Тестирование функции потерь
test_loss_function()
# Дополнительная информация
print("\n" + "=" * 60)
print("ИНФОРМАЦИЯ О МОДЕЛИ")
print("=" * 60)
print("\nАрхитектура модели:")
print("-" * 40)
print("Вход: два изображения 256x256x3")
print("Энкодер: CNN с residual блоками")
print("Слой слияния: объединение признаков")
print("Выход: значение схожести [0, 1]")
print("\nИнтерпретация результатов:")
print("- 0.8-1.0: Очень похожи")
print("- 0.6-0.8: Похожи")
print("- 0.4-0.6: Нейтрально")
print("- 0.2-0.4: Разные")
print("- 0.0-0.2: Совершенно разные")
print("\n" + "=" * 60)
print("ДЕМОНСТРАЦИЯ ЗАВЕРШЕНА")
print("=" * 60)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,216 @@
"""
Пример использования модели оценки схожести с даталоадером.
"""
import torch
from dataloader import YaGoDataset, create_data_loaders
from model import SimilarityCNN, SimilarityLoss
def main():
"""Основной пример использования."""
print("ПРИМЕР ИСПОЛЬЗОВАНИЯ МОДЕЛИ СХОЖЕСТИ С ДАТАЛОАДЕРОМ")
print("=" * 60)
# Конфигурация
config = {
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
"batch_size": 4,
"image_size": (256, 256),
"train_split": 0.8,
"num_workers": 0,
}
# Устройство
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {device}")
# 1. Создание датасета
print("\n1. СОЗДАНИЕ ДАТАСЕТА")
print("-" * 40)
dataset = YaGoDataset(
root_dir=config["data_dir"],
augment=False,
image_size=config["image_size"],
)
print(f"Размер датасета: {len(dataset)} пар изображений")
# Получение примера из датасета
sample = dataset[0]
print(f"\nПример из датасета:")
print(f" Google image shape: {sample['google_img'].shape}")
print(f" Yandex image shape: {sample['yandex_img'].shape}")
print(f" Same domain: {sample['same_domain']}")
print(f" Index: {sample['idx'].item()}")
# 2. Создание даталоадеров
print("\n2. СОЗДАНИЕ ДАТАЛОАДЕРОВ")
print("-" * 40)
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=config["batch_size"],
train_split=config["train_split"],
num_workers=config["num_workers"],
image_size=config["image_size"],
augment_train=True,
augment_val=False,
device=device,
)
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
# 3. Создание модели
print("\n3. СОЗДАНИЕ МОДЕЛИ")
print("-" * 40)
model = SimilarityCNN(
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
).to(device)
print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")
# 4. Тестирование на одном батче
print("\n4. ТЕСТИРОВАНИЕ НА ОДНОМ БАТЧЕ")
print("-" * 40)
# Получаем батч из train_loader
for batch in train_loader:
google_img = batch["google_img"].to(device)
yandex_img = batch["yandex_img"].to(device)
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
print(f"Batch size: {google_img.shape[0]}")
print(f"Image shape: {google_img.shape[1:]}")
print(f"Same domain labels: {same_domain.squeeze().tolist()}")
# Предсказание схожести
with torch.no_grad():
predictions = model.predict_similarity(google_img, yandex_img)
print(f"\nПредсказания схожести:")
for i in range(len(predictions)):
print(
f" Sample {i}: {predictions[i].item():.4f} (target: {same_domain[i].item():.1f})"
)
# Расчет потерь
loss_fn = SimilarityLoss().to(device)
loss = loss_fn(predictions, same_domain)
print(f"\nПотери на батче: {loss.item():.4f}")
# Расчет метрик
metrics = loss_fn.compute_metrics(predictions, same_domain)
print("\nМетрики на батче:")
for key, value in metrics.items():
print(f" {key}: {value:.4f}")
break # Только первый батч
# 5. Обучение на одном эпохе (демонстрация)
print("\n5. ДЕМОНСТРАЦИЯ ОБУЧЕНИЯ НА ОДНОЙ ЭПОХЕ")
print("-" * 40)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
model.train()
total_loss = 0
total_samples = 0
for batch_idx, batch in enumerate(train_loader):
if batch_idx >= 3: # Ограничиваем 3 батчами для демонстрации
break
google_img = batch["google_img"].to(device)
yandex_img = batch["yandex_img"].to(device)
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
optimizer.zero_grad()
predictions = model(google_img, yandex_img)
loss = loss_fn(predictions, same_domain)
loss.backward()
optimizer.step()
batch_loss = loss.item() * google_img.size(0)
total_loss += batch_loss
total_samples += google_img.size(0)
print(f"Batch {batch_idx + 1}: loss = {loss.item():.4f}")
avg_loss = total_loss / total_samples
print(f"\nСредние потери за 3 батча: {avg_loss:.4f}")
# 6. Валидация
print("\n6. ВАЛИДАЦИЯ")
print("-" * 40)
model.eval()
val_loss = 0
val_samples = 0
with torch.no_grad():
for batch_idx, batch in enumerate(val_loader):
if batch_idx >= 2: # Ограничиваем 2 батчами для демонстрации
break
google_img = batch["google_img"].to(device)
yandex_img = batch["yandex_img"].to(device)
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
predictions = model.predict_similarity(google_img, yandex_img)
loss = loss_fn(predictions, same_domain)
val_loss += loss.item() * google_img.size(0)
val_samples += google_img.size(0)
print(f"Val batch {batch_idx + 1}: loss = {loss.item():.4f}")
avg_val_loss = val_loss / val_samples
print(f"\nСредние потери на валидации: {avg_val_loss:.4f}")
# 7. Пример использования для отдельных изображений
print("\n7. ПРИМЕР ДЛЯ ОТДЕЛЬНЫХ ИЗОБРАЖЕНИЙ")
print("-" * 40)
# Берем два примера из датасета
sample1 = dataset[0]
sample2 = dataset[1]
# Подготавливаем тензоры
img1_1 = sample1["google_img"].unsqueeze(0).to(device)
img1_2 = sample1["yandex_img"].unsqueeze(0).to(device)
img2_1 = sample2["google_img"].unsqueeze(0).to(device)
img2_2 = sample2["yandex_img"].unsqueeze(0).to(device)
# Предсказания
with torch.no_grad():
# Сравнение пар из одного домена
sim_same1 = model.predict_similarity(img1_1, img1_2)
sim_same2 = model.predict_similarity(img2_1, img2_2)
# Сравнение пар из разных доменов
sim_diff1 = model.predict_similarity(img1_1, img2_2)
sim_diff2 = model.predict_similarity(img2_1, img1_2)
print("Сравнение пар изображений:")
print(f" Пара 1 (один домен): {sim_same1.item():.4f}")
print(f" Пара 2 (один домен): {sim_same2.item():.4f}")
print(f" Разные домены 1: {sim_diff1.item():.4f}")
print(f" Разные домены 2: {sim_diff2.item():.4f}")
print("\n" + "=" * 60)
print("ПРИМЕР ЗАВЕРШЕН")
print("=" * 60)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,322 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimilarityCNN(nn.Module):
"""
CNN model for similarity estimation between two images.
Takes two images as input and outputs a similarity score between 0 and 1.
"""
def __init__(
self,
input_channels: int = 3,
hidden_channels: int = 64,
num_blocks: int = 4,
dropout_rate: float = 0.3,
use_batch_norm: bool = True,
):
super().__init__()
self.input_channels = input_channels
self.hidden_channels = hidden_channels
self.num_blocks = num_blocks
self.dropout_rate = dropout_rate
self.use_batch_norm = use_batch_norm
self.encoder = self._build_encoder()
self.fusion_layers = self._build_fusion_layers()
self.regression_head = self._build_regression_head()
self._initialize_weights()
def _build_encoder(self) -> nn.Module:
layers = []
in_channels = self.input_channels
out_channels = self.hidden_channels
layers.append(
nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3)
)
if self.use_batch_norm:
layers.append(nn.BatchNorm2d(out_channels))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
for i in range(self.num_blocks):
block_in_channels = out_channels
block_out_channels = out_channels * 2 if i < 2 else out_channels
layers.append(
ResidualBlock(
in_channels=block_in_channels,
out_channels=block_out_channels,
stride=1 if i == 0 else 2,
dropout_rate=self.dropout_rate,
use_batch_norm=self.use_batch_norm,
)
)
if i < 2:
out_channels = block_out_channels
return nn.Sequential(*layers)
def _build_fusion_layers(self) -> nn.Module:
fused_channels = self.hidden_channels * 8
layers = [
nn.Conv2d(
fused_channels, self.hidden_channels * 4, kernel_size=3, padding=1
),
nn.BatchNorm2d(self.hidden_channels * 4)
if self.use_batch_norm
else nn.Identity(),
nn.ReLU(inplace=True),
nn.Dropout2d(self.dropout_rate),
nn.Conv2d(
self.hidden_channels * 4,
self.hidden_channels * 2,
kernel_size=3,
padding=1,
),
nn.BatchNorm2d(self.hidden_channels * 2)
if self.use_batch_norm
else nn.Identity(),
nn.ReLU(inplace=True),
nn.Dropout2d(self.dropout_rate),
nn.AdaptiveAvgPool2d((1, 1)),
]
return nn.Sequential(*layers)
def _build_regression_head(self) -> nn.Module:
input_features = self.hidden_channels * 2
layers = [
nn.Flatten(),
nn.Linear(input_features, 512),
nn.BatchNorm1d(512) if self.use_batch_norm else nn.Identity(),
nn.ReLU(inplace=True),
nn.Dropout(self.dropout_rate),
nn.Linear(512, 256),
nn.BatchNorm1d(256) if self.use_batch_norm else nn.Identity(),
nn.ReLU(inplace=True),
nn.Dropout(self.dropout_rate),
nn.Linear(256, 128),
nn.BatchNorm1d(128) if self.use_batch_norm else nn.Identity(),
nn.ReLU(inplace=True),
nn.Dropout(self.dropout_rate),
nn.Linear(128, 1),
nn.Sigmoid(),
]
return nn.Sequential(*layers)
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def forward(
self,
img1: torch.Tensor,
img2: torch.Tensor,
) -> torch.Tensor:
features1 = self.encoder(img1)
features2 = self.encoder(img2)
combined_features = torch.cat([features1, features2], dim=1)
fused_features = self.fusion_layers(combined_features)
similarity = self.regression_head(fused_features)
return similarity
def predict_similarity(
self,
img1: torch.Tensor,
img2: torch.Tensor,
) -> torch.Tensor:
original_training = self.training
self.eval()
with torch.no_grad():
similarity = self.forward(img1, img2)
if original_training:
self.train()
return similarity
class ResidualBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int = 1,
dropout_rate: float = 0.3,
use_batch_norm: bool = True,
):
super().__init__()
self.conv1 = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
self.bn1 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
self.relu1 = nn.ReLU(inplace=True)
self.dropout1 = (
nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
self.relu2 = nn.ReLU(inplace=True)
self.dropout2 = (
nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=stride, bias=False
),
nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = self.shortcut(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.dropout1(out)
out = self.conv2(out)
out = self.bn2(out)
out += identity
out = self.relu2(out)
out = self.dropout2(out)
return out
class SimilarityLoss(nn.Module):
def __init__(self):
super().__init__()
self.criterion = nn.BCELoss()
def forward(
self,
pred_similarity: torch.Tensor,
target_same: torch.Tensor,
) -> torch.Tensor:
return self.criterion(pred_similarity, target_same)
def compute_metrics(
self,
pred_similarity: torch.Tensor,
target_same: torch.Tensor,
threshold: float = 0.5,
) -> dict:
with torch.no_grad():
pred_binary = (pred_similarity > threshold).float()
target_binary = (target_same > 0.5).float()
correct = (pred_binary == target_binary).float()
accuracy = correct.mean().item()
tp = ((pred_binary == 1) & (target_binary == 1)).float().sum().item()
fp = ((pred_binary == 1) & (target_binary == 0)).float().sum().item()
fn = ((pred_binary == 0) & (target_binary == 1)).float().sum().item()
tn = ((pred_binary == 0) & (target_binary == 0)).float().sum().item()
precision = tp / (tp + fp + 1e-8)
recall = tp / (tp + fn + 1e-8)
f1 = 2 * precision * recall / (precision + recall + 1e-8)
return {
"accuracy": accuracy,
"precision": precision,
"recall": recall,
"f1": f1,
"mean_similarity": pred_similarity.mean().item(),
}
def create_similarity_model(
model_type: str = "cnn",
input_size: Tuple[int, int] = (256, 256),
**kwargs,
) -> nn.Module:
if model_type == "cnn":
return SimilarityCNN(**kwargs)
else:
raise ValueError(f"Unknown model type: {model_type}")
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = SimilarityCNN(
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
).to(device)
print(
f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
)
batch_size = 4
height, width = 256, 256
img1 = torch.randn(batch_size, 3, height, width).to(device)
img2 = torch.randn(batch_size, 3, height, width).to(device)
print("\nTesting forward pass...")
output = model(img1, img2)
print(f"Output shape: {output.shape}")
print(f"Sample output: {output[0].item():.4f}")
print("\nTesting prediction...")
pred = model.predict_similarity(img1, img2)
print(f"Prediction shape: {pred.shape}")
print("\nTesting loss function...")
target = torch.rand(batch_size, 1).to(device)
loss_fn = SimilarityLoss().to(device)
loss = loss_fn(output, target)
print(f"Loss value: {loss.item():.6f}")
print("\nTesting metrics...")
metrics = loss_fn.compute_metrics(output, target)
for key, value in metrics.items():
print(f"{key}: {value:.6f}")
print("\nAll tests completed successfully!")

View File

@@ -0,0 +1,146 @@
"""
Script for predicting similarity between two images.
"""
import argparse
import os
from pathlib import Path
import torch
from model import SimilarityCNN
from PIL import Image
from torchvision import transforms
def load_image(image_path: str, image_size: tuple = (256, 256)) -> torch.Tensor:
"""Load and preprocess image."""
transform = transforms.Compose(
[
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
image = Image.open(image_path).convert("RGB")
return transform(image).unsqueeze(0) # Add batch dimension
def predict_similarity(
model: SimilarityCNN,
image1_path: str,
image2_path: str,
device: torch.device,
image_size: tuple = (256, 256),
) -> float:
"""Predict similarity between two images."""
model.eval()
img1 = load_image(image1_path, image_size).to(device)
img2 = load_image(image2_path, image_size).to(device)
with torch.no_grad():
similarity = model(img1, img2)
return similarity.item()
def load_model(
checkpoint_path: str,
device: torch.device,
**model_kwargs,
) -> SimilarityCNN:
"""Load model from checkpoint."""
model = SimilarityCNN(**model_kwargs).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
return model
def main():
parser = argparse.ArgumentParser(
description="Predict similarity between two images"
)
parser.add_argument("--image1", type=str, required=True, help="Path to first image")
parser.add_argument(
"--image2", type=str, required=True, help="Path to second image"
)
parser.add_argument(
"--checkpoint",
type=str,
default="runs/similarity/checkpoints/best_model.pt",
help="Path to model checkpoint",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to use for inference",
)
parser.add_argument(
"--image_size",
type=int,
default=256,
help="Image size for model input",
)
args = parser.parse_args()
device = torch.device(args.device)
print(f"Using device: {device}")
if not os.path.exists(args.image1):
print(f"Error: Image not found: {args.image1}")
return
if not os.path.exists(args.image2):
print(f"Error: Image not found: {args.image2}")
return
if not os.path.exists(args.checkpoint):
print(f"Warning: Checkpoint not found: {args.checkpoint}")
print("Using randomly initialized model for demonstration")
model = SimilarityCNN(
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
).to(device)
else:
print(f"Loading model from: {args.checkpoint}")
model = load_model(
checkpoint_path=args.checkpoint,
device=device,
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
)
print(
f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters"
)
similarity = predict_similarity(
model=model,
image1_path=args.image1,
image2_path=args.image2,
device=device,
image_size=(args.image_size, args.image_size),
)
print(f"\nSimilarity between images:")
print(f" Image 1: {args.image1}")
print(f" Image 2: {args.image2}")
print(f" Similarity score: {similarity:.4f}")
print(f" Interpretation: {'Similar' if similarity > 0.5 else 'Different'}")
return similarity
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,275 @@
"""
Training script for image similarity estimation.
"""
import argparse
import os
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import create_data_loaders
from model import SimilarityCNN, SimilarityLoss, create_similarity_model
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
class SimilarityTrainer:
def __init__(
self,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
device: torch.device,
config: dict,
):
self.model = model.to(device)
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
self.config = config
self.criterion = SimilarityLoss()
self.optimizer = optim.Adam(
model.parameters(),
lr=config.get("learning_rate", 2e-4),
betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)),
)
self.writer = None
self.best_val_loss = float("inf")
self.epochs_without_improvement = 0
def train_epoch(self, epoch: int) -> dict:
self.model.train()
total_loss = 0
total_samples = 0
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
for batch_idx, batch in enumerate(pbar):
google_img = batch["google_img"].to(self.device)
yandex_img = batch["yandex_img"].to(self.device)
target = batch["same_domain"].float().to(self.device).unsqueeze(1)
self.optimizer.zero_grad()
output = self.model(google_img, yandex_img)
loss = self.criterion(output, target)
loss.backward()
self.optimizer.step()
total_loss += loss.item() * google_img.size(0)
total_samples += google_img.size(0)
if batch_idx % self.config.get("log_interval", 10) == 0:
metrics = self.criterion.compute_metrics(output, target)
pbar.set_postfix(
{
"loss": loss.item(),
"acc": metrics["accuracy"],
}
)
if self.writer:
self.writer.add_scalar(
"train/loss",
loss.item(),
epoch * len(self.train_loader) + batch_idx,
)
self.writer.add_scalar(
"train/accuracy",
metrics["accuracy"],
epoch * len(self.train_loader) + batch_idx,
)
avg_loss = total_loss / total_samples
return {"loss": avg_loss}
def validate(self) -> dict:
self.model.eval()
total_loss = 0
total_samples = 0
all_metrics = []
with torch.no_grad():
for batch in tqdm(self.val_loader, desc="Validation"):
google_img = batch["google_img"].to(self.device)
yandex_img = batch["yandex_img"].to(self.device)
target = batch["same_domain"].float().to(self.device).unsqueeze(1)
output = self.model(google_img, yandex_img)
loss = self.criterion(output, target)
total_loss += loss.item() * google_img.size(0)
total_samples += google_img.size(0)
metrics = self.criterion.compute_metrics(output, target)
all_metrics.append(metrics)
avg_loss = total_loss / total_samples
avg_metrics = {}
for key in all_metrics[0].keys():
avg_metrics[key] = sum(m[key] for m in all_metrics) / len(all_metrics)
return {"loss": avg_loss, **avg_metrics}
def train(self, num_epochs: int):
log_dir = self.config.get("output_dir", "runs/similarity")
os.makedirs(log_dir, exist_ok=True)
self.writer = SummaryWriter(log_dir)
print(f"Starting training for {num_epochs} epochs")
print(f"Logging to: {log_dir}")
for epoch in range(1, num_epochs + 1):
print(f"\nEpoch {epoch}/{num_epochs}")
train_metrics = self.train_epoch(epoch)
val_metrics = self.validate()
print(f"Train Loss: {train_metrics['loss']:.4f}")
print(f"Val Loss: {val_metrics['loss']:.4f}")
print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
print(f"Val F1: {val_metrics['f1']:.4f}")
if self.writer:
self.writer.add_scalar("epoch/train_loss", train_metrics["loss"], epoch)
self.writer.add_scalar("epoch/val_loss", val_metrics["loss"], epoch)
self.writer.add_scalar(
"epoch/val_accuracy", val_metrics["accuracy"], epoch
)
if val_metrics["loss"] < self.best_val_loss:
self.best_val_loss = val_metrics["loss"]
self.epochs_without_improvement = 0
self.save_checkpoint(epoch, val_metrics["loss"], is_best=True)
print(f"New best model saved with val loss: {val_metrics['loss']:.4f}")
else:
self.epochs_without_improvement += 1
self.save_checkpoint(epoch, val_metrics["loss"], is_best=False)
patience = self.config.get("early_stopping_patience", 20)
if self.epochs_without_improvement >= patience:
print(
f"Early stopping triggered after {patience} epochs without improvement"
)
break
self.writer.close()
def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
checkpoint_dir = os.path.join(
self.config.get("output_dir", "runs/similarity"), "checkpoints"
)
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint = {
"epoch": epoch,
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"val_loss": val_loss,
"config": self.config,
}
checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
torch.save(checkpoint, checkpoint_path)
if is_best:
best_path = os.path.join(checkpoint_dir, "best_model.pt")
torch.save(checkpoint, best_path)
def load_checkpoint(self, checkpoint_path: str):
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
return checkpoint["epoch"], checkpoint["val_loss"]
def main():
parser = argparse.ArgumentParser(description="Train similarity estimation model")
parser.add_argument(
"--data_dir",
type=str,
default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--image_size", type=int, default=256)
parser.add_argument("--train_split", type=float, default=0.8)
parser.add_argument("--output_dir", type=str, default="runs/similarity")
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
args = parser.parse_args()
config = {
"data_dir": args.data_dir,
"batch_size": args.batch_size,
"epochs": args.epochs,
"learning_rate": args.learning_rate,
"image_size": (args.image_size, args.image_size),
"train_split": args.train_split,
"output_dir": args.output_dir,
"num_workers": args.num_workers,
"log_interval": 10,
"save_interval": 5,
"early_stopping_patience": 20,
"beta1": 0.5,
"beta2": 0.999,
}
device = torch.device(args.device)
print(f"Using device: {device}")
print("Creating data loaders...")
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=config["batch_size"],
train_split=config["train_split"],
num_workers=config["num_workers"],
image_size=config["image_size"],
augment_train=True,
augment_val=False,
device=device,
)
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print("Creating model...")
model = create_similarity_model(
model_type="cnn",
input_size=config["image_size"],
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
trainer = SimilarityTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
print("Starting training...")
trainer.train(config["epochs"])
print("Training completed!")
if __name__ == "__main__":
main()