Files
autopilot/models/SiaN-similarity/README.md

131 lines
3.8 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# SiaN-Similarity: Модель для оценки схожести изображений
Модель для оценки схожести между двумя изображениями 256x256. Возвращает значение от 0 до 1, где 1 означает полную схожесть, 0 - полное различие.
## Архитектура модели
Модель основана на CNN с residual блоками:
- Общий энкодер для обоих изображений
- Residual blocks с batch normalization
- Слой слияния признаков
- Регрессионная голова с сигмоидой на выходе
## Использование
### Установка зависимостей
```bash
pip install torch torchvision pillow
```
### Быстрый старт
```python
import torch
from model import SimilarityCNN
# Создание модели
model = SimilarityCNN(
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
)
# Предсказание схожести
img1 = torch.randn(1, 3, 256, 256) # Изображение 1
img2 = torch.randn(1, 3, 256, 256) # Изображение 2
similarity = model.predict_similarity(img1, img2)
print(f"Схожесть: {similarity.item():.4f}")
```
### Обучение модели
```bash
python train_similarity.py \
--data_dir "путь/к/данным" \
--batch_size 32 \
--epochs 100 \
--learning_rate 2e-4 \
--output_dir "runs/similarity"
```
### Предсказание на новых изображениях
```bash
python predict.py \
--image1 "путь/к/изображению1.png" \
--image2 "путь/к/изображению2.png" \
--checkpoint "runs/similarity/checkpoints/best_model.pt"
```
## Структура проекта
```
SiaN-similarity/
├── model.py # Основная модель
├── dataloader.py # Даталоадер для обучения
├── train_similarity.py # Скрипт для обучения
├── predict.py # Скрипт для предсказания
├── train.py # Оригинальный тренировочный скрипт
└── README.md # Этот файл
```
## Конфигурация модели
Параметры по умолчанию:
- `input_channels`: 3 (RGB)
- `hidden_channels`: 64
- `num_blocks`: 4
- `dropout_rate`: 0.3
- `use_batch_norm`: True
- `image_size`: (256, 256)
## Формат данных
Модель ожидает изображения размером 256x256 пикселей в формате RGB.
Для обучения используется датасет с парами изображений и метками схожести.
## Примеры использования
### 1. Создание и тестирование модели
```python
from model import create_similarity_model
model = create_similarity_model(
model_type="cnn",
input_size=(256, 256),
hidden_channels=32,
num_blocks=3,
)
```
### 2. Использование функции потерь
```python
from model import SimilarityLoss
loss_fn = SimilarityLoss()
pred = torch.tensor([[0.8], [0.2]])
target = torch.tensor([[1.0], [0.0]])
loss = loss_fn(pred, target)
```
### 3. Расчет метрик
```python
metrics = loss_fn.compute_metrics(pred, target)
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"F1-score: {metrics['f1']:.4f}")
```
## Требования
- Python 3.8+
- PyTorch 1.9+
- torchvision
- Pillow
- numpy
## Лицензия
MIT