feat: add similarity model

This commit is contained in:
2026-03-03 21:42:23 +03:00
parent 1de150b386
commit 43cd4222bc
7 changed files with 1801 additions and 0 deletions

View File

@@ -0,0 +1,131 @@
# SiaN-Similarity: Модель для оценки схожести изображений
Модель для оценки схожести между двумя изображениями 256x256. Возвращает значение от 0 до 1, где 1 означает полную схожесть, 0 - полное различие.
## Архитектура модели
Модель основана на CNN с residual блоками:
- Общий энкодер для обоих изображений
- Residual blocks с batch normalization
- Слой слияния признаков
- Регрессионная голова с сигмоидой на выходе
## Использование
### Установка зависимостей
```bash
pip install torch torchvision pillow
```
### Быстрый старт
```python
import torch
from model import SimilarityCNN
# Создание модели
model = SimilarityCNN(
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
)
# Предсказание схожести
img1 = torch.randn(1, 3, 256, 256) # Изображение 1
img2 = torch.randn(1, 3, 256, 256) # Изображение 2
similarity = model.predict_similarity(img1, img2)
print(f"Схожесть: {similarity.item():.4f}")
```
### Обучение модели
```bash
python train_similarity.py \
--data_dir "путь/к/данным" \
--batch_size 32 \
--epochs 100 \
--learning_rate 2e-4 \
--output_dir "runs/similarity"
```
### Предсказание на новых изображениях
```bash
python predict.py \
--image1 "путь/к/изображению1.png" \
--image2 "путь/к/изображению2.png" \
--checkpoint "runs/similarity/checkpoints/best_model.pt"
```
## Структура проекта
```
SiaN-similarity/
├── model.py # Основная модель
├── dataloader.py # Даталоадер для обучения
├── train_similarity.py # Скрипт для обучения
├── predict.py # Скрипт для предсказания
├── train.py # Оригинальный тренировочный скрипт
└── README.md # Этот файл
```
## Конфигурация модели
Параметры по умолчанию:
- `input_channels`: 3 (RGB)
- `hidden_channels`: 64
- `num_blocks`: 4
- `dropout_rate`: 0.3
- `use_batch_norm`: True
- `image_size`: (256, 256)
## Формат данных
Модель ожидает изображения размером 256x256 пикселей в формате RGB.
Для обучения используется датасет с парами изображений и метками схожести.
## Примеры использования
### 1. Создание и тестирование модели
```python
from model import create_similarity_model
model = create_similarity_model(
model_type="cnn",
input_size=(256, 256),
hidden_channels=32,
num_blocks=3,
)
```
### 2. Использование функции потерь
```python
from model import SimilarityLoss
loss_fn = SimilarityLoss()
pred = torch.tensor([[0.8], [0.2]])
target = torch.tensor([[1.0], [0.0]])
loss = loss_fn(pred, target)
```
### 3. Расчет метрик
```python
metrics = loss_fn.compute_metrics(pred, target)
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"F1-score: {metrics['f1']:.4f}")
```
## Требования
- Python 3.8+
- PyTorch 1.9+
- torchvision
- Pillow
- numpy
## Лицензия
MIT