131 lines
3.8 KiB
Markdown
131 lines
3.8 KiB
Markdown
# SiaN-Similarity: Модель для оценки схожести изображений
|
||
|
||
Модель для оценки схожести между двумя изображениями 256x256. Возвращает значение от 0 до 1, где 1 означает полную схожесть, 0 - полное различие.
|
||
|
||
## Архитектура модели
|
||
|
||
Модель основана на CNN с residual блоками:
|
||
- Общий энкодер для обоих изображений
|
||
- Residual blocks с batch normalization
|
||
- Слой слияния признаков
|
||
- Регрессионная голова с сигмоидой на выходе
|
||
|
||
## Использование
|
||
|
||
### Установка зависимостей
|
||
```bash
|
||
pip install torch torchvision pillow
|
||
```
|
||
|
||
### Быстрый старт
|
||
|
||
```python
|
||
import torch
|
||
from model import SimilarityCNN
|
||
|
||
# Создание модели
|
||
model = SimilarityCNN(
|
||
input_channels=3,
|
||
hidden_channels=64,
|
||
num_blocks=4,
|
||
dropout_rate=0.3,
|
||
use_batch_norm=True,
|
||
)
|
||
|
||
# Предсказание схожести
|
||
img1 = torch.randn(1, 3, 256, 256) # Изображение 1
|
||
img2 = torch.randn(1, 3, 256, 256) # Изображение 2
|
||
|
||
similarity = model.predict_similarity(img1, img2)
|
||
print(f"Схожесть: {similarity.item():.4f}")
|
||
```
|
||
|
||
### Обучение модели
|
||
|
||
```bash
|
||
python train_similarity.py \
|
||
--data_dir "путь/к/данным" \
|
||
--batch_size 32 \
|
||
--epochs 100 \
|
||
--learning_rate 2e-4 \
|
||
--output_dir "runs/similarity"
|
||
```
|
||
|
||
### Предсказание на новых изображениях
|
||
|
||
```bash
|
||
python predict.py \
|
||
--image1 "путь/к/изображению1.png" \
|
||
--image2 "путь/к/изображению2.png" \
|
||
--checkpoint "runs/similarity/checkpoints/best_model.pt"
|
||
```
|
||
|
||
## Структура проекта
|
||
|
||
```
|
||
SiaN-similarity/
|
||
├── model.py # Основная модель
|
||
├── dataloader.py # Даталоадер для обучения
|
||
├── train_similarity.py # Скрипт для обучения
|
||
├── predict.py # Скрипт для предсказания
|
||
├── train.py # Оригинальный тренировочный скрипт
|
||
└── README.md # Этот файл
|
||
```
|
||
|
||
## Конфигурация модели
|
||
|
||
Параметры по умолчанию:
|
||
- `input_channels`: 3 (RGB)
|
||
- `hidden_channels`: 64
|
||
- `num_blocks`: 4
|
||
- `dropout_rate`: 0.3
|
||
- `use_batch_norm`: True
|
||
- `image_size`: (256, 256)
|
||
|
||
## Формат данных
|
||
|
||
Модель ожидает изображения размером 256x256 пикселей в формате RGB.
|
||
Для обучения используется датасет с парами изображений и метками схожести.
|
||
|
||
## Примеры использования
|
||
|
||
### 1. Создание и тестирование модели
|
||
```python
|
||
from model import create_similarity_model
|
||
|
||
model = create_similarity_model(
|
||
model_type="cnn",
|
||
input_size=(256, 256),
|
||
hidden_channels=32,
|
||
num_blocks=3,
|
||
)
|
||
```
|
||
|
||
### 2. Использование функции потерь
|
||
```python
|
||
from model import SimilarityLoss
|
||
|
||
loss_fn = SimilarityLoss()
|
||
pred = torch.tensor([[0.8], [0.2]])
|
||
target = torch.tensor([[1.0], [0.0]])
|
||
loss = loss_fn(pred, target)
|
||
```
|
||
|
||
### 3. Расчет метрик
|
||
```python
|
||
metrics = loss_fn.compute_metrics(pred, target)
|
||
print(f"Accuracy: {metrics['accuracy']:.4f}")
|
||
print(f"F1-score: {metrics['f1']:.4f}")
|
||
```
|
||
|
||
## Требования
|
||
|
||
- Python 3.8+
|
||
- PyTorch 1.9+
|
||
- torchvision
|
||
- Pillow
|
||
- numpy
|
||
|
||
## Лицензия
|
||
|
||
MIT |