3.8 KiB
3.8 KiB
SiaN-Similarity: Модель для оценки схожести изображений
Модель для оценки схожести между двумя изображениями 256x256. Возвращает значение от 0 до 1, где 1 означает полную схожесть, 0 - полное различие.
Архитектура модели
Модель основана на CNN с residual блоками:
- Общий энкодер для обоих изображений
- Residual blocks с batch normalization
- Слой слияния признаков
- Регрессионная голова с сигмоидой на выходе
Использование
Установка зависимостей
pip install torch torchvision pillow
Быстрый старт
import torch
from model import SimilarityCNN
# Создание модели
model = SimilarityCNN(
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
)
# Предсказание схожести
img1 = torch.randn(1, 3, 256, 256) # Изображение 1
img2 = torch.randn(1, 3, 256, 256) # Изображение 2
similarity = model.predict_similarity(img1, img2)
print(f"Схожесть: {similarity.item():.4f}")
Обучение модели
python train_similarity.py \
--data_dir "путь/к/данным" \
--batch_size 32 \
--epochs 100 \
--learning_rate 2e-4 \
--output_dir "runs/similarity"
Предсказание на новых изображениях
python predict.py \
--image1 "путь/к/изображению1.png" \
--image2 "путь/к/изображению2.png" \
--checkpoint "runs/similarity/checkpoints/best_model.pt"
Структура проекта
SiaN-similarity/
├── model.py # Основная модель
├── dataloader.py # Даталоадер для обучения
├── train_similarity.py # Скрипт для обучения
├── predict.py # Скрипт для предсказания
├── train.py # Оригинальный тренировочный скрипт
└── README.md # Этот файл
Конфигурация модели
Параметры по умолчанию:
input_channels: 3 (RGB)hidden_channels: 64num_blocks: 4dropout_rate: 0.3use_batch_norm: Trueimage_size: (256, 256)
Формат данных
Модель ожидает изображения размером 256x256 пикселей в формате RGB. Для обучения используется датасет с парами изображений и метками схожести.
Примеры использования
1. Создание и тестирование модели
from model import create_similarity_model
model = create_similarity_model(
model_type="cnn",
input_size=(256, 256),
hidden_channels=32,
num_blocks=3,
)
2. Использование функции потерь
from model import SimilarityLoss
loss_fn = SimilarityLoss()
pred = torch.tensor([[0.8], [0.2]])
target = torch.tensor([[1.0], [0.0]])
loss = loss_fn(pred, target)
3. Расчет метрик
metrics = loss_fn.compute_metrics(pred, target)
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"F1-score: {metrics['f1']:.4f}")
Требования
- Python 3.8+
- PyTorch 1.9+
- torchvision
- Pillow
- numpy
Лицензия
MIT