feat: add similarity model

This commit is contained in:
2026-03-03 21:42:23 +03:00
parent 1de150b386
commit 43cd4222bc
7 changed files with 1801 additions and 0 deletions

View File

@@ -0,0 +1,192 @@
"""
Демонстрационный скрипт для модели оценки схожести изображений.
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
from model import SimilarityCNN, SimilarityLoss
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
def create_test_images():
"""Создание тестовых изображений для демонстрации."""
images = []
# Изображение 1: Красный квадрат
img1 = Image.new("RGB", (256, 256), color="white")
draw = ImageDraw.Draw(img1)
draw.rectangle([50, 50, 200, 200], fill="red", outline="black", width=2)
images.append(("Красный квадрат", img1))
# Изображение 2: Тот же красный квадрат (похожее)
img2 = Image.new("RGB", (256, 256), color="white")
draw = ImageDraw.Draw(img2)
draw.rectangle([55, 55, 205, 205], fill="red", outline="black", width=2)
images.append(("Похожий красный квадрат", img2))
# Изображение 3: Синий круг (разное)
img3 = Image.new("RGB", (256, 256), color="white")
draw = ImageDraw.Draw(img3)
draw.ellipse([50, 50, 200, 200], fill="blue", outline="black", width=2)
images.append(("Синий круг", img3))
# Изображение 4: Зеленый треугольник (разное)
img4 = Image.new("RGB", (256, 256), color="white")
draw = ImageDraw.Draw(img4)
draw.polygon(
[(128, 50), (50, 200), (200, 200)], fill="green", outline="black", width=2
)
images.append(("Зеленый треугольник", img4))
return images
def preprocess_image(image):
"""Преобразование PIL Image в тензор."""
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
return transform(image).unsqueeze(0) # Добавляем batch dimension
def display_results(images, similarities):
"""Отображение результатов сравнения."""
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()
for idx, (title, img) in enumerate(images):
ax = axes[idx]
ax.imshow(img)
ax.set_title(title, fontsize=12, fontweight="bold")
ax.axis("off")
plt.suptitle("Тестовые изображения", fontsize=16, fontweight="bold")
plt.tight_layout()
plt.show()
# Вывод результатов сравнения
print("\n" + "=" * 60)
print("РЕЗУЛЬТАТЫ СРАВНЕНИЯ ИЗОБРАЖЕНИЙ")
print("=" * 60)
comparisons = [
("Красный квадрат", "Похожий красный квадрат"),
("Красный квадрат", "Синий круг"),
("Красный квадрат", "Зеленый треугольник"),
("Похожий красный квадрат", "Синий круг"),
]
for i, (name1, name2) in enumerate(comparisons):
idx1 = [idx for idx, (name, _) in enumerate(images) if name == name1][0]
idx2 = [idx for idx, (name, _) in enumerate(images) if name == name2][0]
sim = similarities[idx1, idx2]
interpretation = "ПОХОЖИ" if sim > 0.5 else "РАЗНЫЕ"
print(f"\n{name1} vs {name2}:")
print(f" Схожесть: {sim:.4f}")
print(f" Интерпретация: {interpretation}")
print(f" Уверенность: {'Высокая' if sim > 0.7 or sim < 0.3 else 'Средняя'}")
def test_loss_function():
"""Тестирование функции потерь."""
print("\n" + "=" * 60)
print("ТЕСТИРОВАНИЕ ФУНКЦИИ ПОТЕРЬ")
print("=" * 60)
loss_fn = SimilarityLoss()
# Тестовые данные
predictions = torch.tensor([[0.9], [0.1], [0.7], [0.3]])
targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
# Расчет потерь
loss = loss_fn(predictions, targets)
print(f"\nПотери: {loss.item():.4f}")
# Расчет метрик
metrics = loss_fn.compute_metrics(predictions, targets)
print("\nМетрики:")
for key, value in metrics.items():
print(f" {key}: {value:.4f}")
def main():
"""Основная функция демонстрации."""
print("ДЕМОНСТРАЦИЯ МОДЕЛИ ОЦЕНКИ СХОЖЕСТИ ИЗОБРАЖЕНИЙ")
print("=" * 60)
# Создание модели
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nУстройство: {device}")
model = SimilarityCNN(
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
).to(device)
print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")
# Создание тестовых изображений
print("\nСоздание тестовых изображений...")
test_images = create_test_images()
# Преобразование изображений в тензоры
tensors = []
for name, img in test_images:
tensor = preprocess_image(img).to(device)
tensors.append(tensor)
# Расчет схожести между всеми парами изображений
print("\nРасчет схожести между изображениями...")
n_images = len(test_images)
similarity_matrix = np.zeros((n_images, n_images))
model.eval()
with torch.no_grad():
for i in range(n_images):
for j in range(n_images):
if i <= j: # Рассчитываем только верхний треугольник
sim = model.predict_similarity(tensors[i], tensors[j])
similarity_matrix[i, j] = sim.item()
similarity_matrix[j, i] = sim.item() # Симметричная матрица
# Отображение результатов
display_results(test_images, similarity_matrix)
# Тестирование функции потерь
test_loss_function()
# Дополнительная информация
print("\n" + "=" * 60)
print("ИНФОРМАЦИЯ О МОДЕЛИ")
print("=" * 60)
print("\nАрхитектура модели:")
print("-" * 40)
print("Вход: два изображения 256x256x3")
print("Энкодер: CNN с residual блоками")
print("Слой слияния: объединение признаков")
print("Выход: значение схожести [0, 1]")
print("\nИнтерпретация результатов:")
print("- 0.8-1.0: Очень похожи")
print("- 0.6-0.8: Похожи")
print("- 0.4-0.6: Нейтрально")
print("- 0.2-0.4: Разные")
print("- 0.0-0.2: Совершенно разные")
print("\n" + "=" * 60)
print("ДЕМОНСТРАЦИЯ ЗАВЕРШЕНА")
print("=" * 60)
if __name__ == "__main__":
main()