feat: add similarity model
This commit is contained in:
131
models/SiaN-similarity/README.md
Normal file
131
models/SiaN-similarity/README.md
Normal file
@@ -0,0 +1,131 @@
|
||||
# SiaN-Similarity: Модель для оценки схожести изображений
|
||||
|
||||
Модель для оценки схожести между двумя изображениями 256x256. Возвращает значение от 0 до 1, где 1 означает полную схожесть, 0 - полное различие.
|
||||
|
||||
## Архитектура модели
|
||||
|
||||
Модель основана на CNN с residual блоками:
|
||||
- Общий энкодер для обоих изображений
|
||||
- Residual blocks с batch normalization
|
||||
- Слой слияния признаков
|
||||
- Регрессионная голова с сигмоидой на выходе
|
||||
|
||||
## Использование
|
||||
|
||||
### Установка зависимостей
|
||||
```bash
|
||||
pip install torch torchvision pillow
|
||||
```
|
||||
|
||||
### Быстрый старт
|
||||
|
||||
```python
|
||||
import torch
|
||||
from model import SimilarityCNN
|
||||
|
||||
# Создание модели
|
||||
model = SimilarityCNN(
|
||||
input_channels=3,
|
||||
hidden_channels=64,
|
||||
num_blocks=4,
|
||||
dropout_rate=0.3,
|
||||
use_batch_norm=True,
|
||||
)
|
||||
|
||||
# Предсказание схожести
|
||||
img1 = torch.randn(1, 3, 256, 256) # Изображение 1
|
||||
img2 = torch.randn(1, 3, 256, 256) # Изображение 2
|
||||
|
||||
similarity = model.predict_similarity(img1, img2)
|
||||
print(f"Схожесть: {similarity.item():.4f}")
|
||||
```
|
||||
|
||||
### Обучение модели
|
||||
|
||||
```bash
|
||||
python train_similarity.py \
|
||||
--data_dir "путь/к/данным" \
|
||||
--batch_size 32 \
|
||||
--epochs 100 \
|
||||
--learning_rate 2e-4 \
|
||||
--output_dir "runs/similarity"
|
||||
```
|
||||
|
||||
### Предсказание на новых изображениях
|
||||
|
||||
```bash
|
||||
python predict.py \
|
||||
--image1 "путь/к/изображению1.png" \
|
||||
--image2 "путь/к/изображению2.png" \
|
||||
--checkpoint "runs/similarity/checkpoints/best_model.pt"
|
||||
```
|
||||
|
||||
## Структура проекта
|
||||
|
||||
```
|
||||
SiaN-similarity/
|
||||
├── model.py # Основная модель
|
||||
├── dataloader.py # Даталоадер для обучения
|
||||
├── train_similarity.py # Скрипт для обучения
|
||||
├── predict.py # Скрипт для предсказания
|
||||
├── train.py # Оригинальный тренировочный скрипт
|
||||
└── README.md # Этот файл
|
||||
```
|
||||
|
||||
## Конфигурация модели
|
||||
|
||||
Параметры по умолчанию:
|
||||
- `input_channels`: 3 (RGB)
|
||||
- `hidden_channels`: 64
|
||||
- `num_blocks`: 4
|
||||
- `dropout_rate`: 0.3
|
||||
- `use_batch_norm`: True
|
||||
- `image_size`: (256, 256)
|
||||
|
||||
## Формат данных
|
||||
|
||||
Модель ожидает изображения размером 256x256 пикселей в формате RGB.
|
||||
Для обучения используется датасет с парами изображений и метками схожести.
|
||||
|
||||
## Примеры использования
|
||||
|
||||
### 1. Создание и тестирование модели
|
||||
```python
|
||||
from model import create_similarity_model
|
||||
|
||||
model = create_similarity_model(
|
||||
model_type="cnn",
|
||||
input_size=(256, 256),
|
||||
hidden_channels=32,
|
||||
num_blocks=3,
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Использование функции потерь
|
||||
```python
|
||||
from model import SimilarityLoss
|
||||
|
||||
loss_fn = SimilarityLoss()
|
||||
pred = torch.tensor([[0.8], [0.2]])
|
||||
target = torch.tensor([[1.0], [0.0]])
|
||||
loss = loss_fn(pred, target)
|
||||
```
|
||||
|
||||
### 3. Расчет метрик
|
||||
```python
|
||||
metrics = loss_fn.compute_metrics(pred, target)
|
||||
print(f"Accuracy: {metrics['accuracy']:.4f}")
|
||||
print(f"F1-score: {metrics['f1']:.4f}")
|
||||
```
|
||||
|
||||
## Требования
|
||||
|
||||
- Python 3.8+
|
||||
- PyTorch 1.9+
|
||||
- torchvision
|
||||
- Pillow
|
||||
- numpy
|
||||
|
||||
## Лицензия
|
||||
|
||||
MIT
|
||||
Reference in New Issue
Block a user