Files
autopilot/models/SiaN-similarity
2026-04-04 17:49:31 +03:00
..
2026-04-04 17:49:31 +03:00
2026-03-03 21:42:23 +03:00
2026-03-22 14:29:00 +03:00
2026-03-03 21:42:23 +03:00
2026-03-22 14:29:00 +03:00
2026-03-03 21:42:23 +03:00
2026-03-03 21:42:23 +03:00
2026-04-04 17:49:31 +03:00
2026-03-22 14:29:00 +03:00
2026-03-22 14:29:00 +03:00

SiaN-Similarity: Модель для оценки схожести изображений

Модель для оценки схожести между двумя изображениями 256x256. Возвращает значение от 0 до 1, где 1 означает полную схожесть, 0 - полное различие.

Архитектура модели

Модель основана на CNN с residual блоками:

  • Общий энкодер для обоих изображений
  • Residual blocks с batch normalization
  • Слой слияния признаков
  • Регрессионная голова с сигмоидой на выходе

Использование

Установка зависимостей

pip install torch torchvision pillow

Быстрый старт

import torch
from model import SimilarityCNN

# Создание модели
model = SimilarityCNN(
    input_channels=3,
    hidden_channels=64,
    num_blocks=4,
    dropout_rate=0.3,
    use_batch_norm=True,
)

# Предсказание схожести
img1 = torch.randn(1, 3, 256, 256)  # Изображение 1
img2 = torch.randn(1, 3, 256, 256)  # Изображение 2

similarity = model.predict_similarity(img1, img2)
print(f"Схожесть: {similarity.item():.4f}")

Обучение модели

python train_similarity.py \
    --data_dir "путь/к/данным" \
    --batch_size 32 \
    --epochs 100 \
    --learning_rate 2e-4 \
    --output_dir "runs/similarity"

Предсказание на новых изображениях

python predict.py \
    --image1 "путь/к/изображению1.png" \
    --image2 "путь/к/изображению2.png" \
    --checkpoint "runs/similarity/checkpoints/best_model.pt"

Структура проекта

SiaN-similarity/
├── model.py              # Основная модель
├── dataloader.py         # Даталоадер для обучения
├── train_similarity.py   # Скрипт для обучения
├── predict.py            # Скрипт для предсказания
├── train.py              # Оригинальный тренировочный скрипт
└── README.md             # Этот файл

Конфигурация модели

Параметры по умолчанию:

  • input_channels: 3 (RGB)
  • hidden_channels: 64
  • num_blocks: 4
  • dropout_rate: 0.3
  • use_batch_norm: True
  • image_size: (256, 256)

Формат данных

Модель ожидает изображения размером 256x256 пикселей в формате RGB. Для обучения используется датасет с парами изображений и метками схожести.

Примеры использования

1. Создание и тестирование модели

from model import create_similarity_model

model = create_similarity_model(
    model_type="cnn",
    input_size=(256, 256),
    hidden_channels=32,
    num_blocks=3,
)

2. Использование функции потерь

from model import SimilarityLoss

loss_fn = SimilarityLoss()
pred = torch.tensor([[0.8], [0.2]])
target = torch.tensor([[1.0], [0.0]])
loss = loss_fn(pred, target)

3. Расчет метрик

metrics = loss_fn.compute_metrics(pred, target)
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"F1-score: {metrics['f1']:.4f}")

Требования

  • Python 3.8+
  • PyTorch 1.9+
  • torchvision
  • Pillow
  • numpy

Лицензия

MIT