feat: add similarity model
This commit is contained in:
192
models/SiaN-similarity/demo.py
Normal file
192
models/SiaN-similarity/demo.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Демонстрационный скрипт для модели оценки схожести изображений.
|
||||
"""
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from model import SimilarityCNN, SimilarityLoss
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def create_test_images():
|
||||
"""Создание тестовых изображений для демонстрации."""
|
||||
images = []
|
||||
|
||||
# Изображение 1: Красный квадрат
|
||||
img1 = Image.new("RGB", (256, 256), color="white")
|
||||
draw = ImageDraw.Draw(img1)
|
||||
draw.rectangle([50, 50, 200, 200], fill="red", outline="black", width=2)
|
||||
images.append(("Красный квадрат", img1))
|
||||
|
||||
# Изображение 2: Тот же красный квадрат (похожее)
|
||||
img2 = Image.new("RGB", (256, 256), color="white")
|
||||
draw = ImageDraw.Draw(img2)
|
||||
draw.rectangle([55, 55, 205, 205], fill="red", outline="black", width=2)
|
||||
images.append(("Похожий красный квадрат", img2))
|
||||
|
||||
# Изображение 3: Синий круг (разное)
|
||||
img3 = Image.new("RGB", (256, 256), color="white")
|
||||
draw = ImageDraw.Draw(img3)
|
||||
draw.ellipse([50, 50, 200, 200], fill="blue", outline="black", width=2)
|
||||
images.append(("Синий круг", img3))
|
||||
|
||||
# Изображение 4: Зеленый треугольник (разное)
|
||||
img4 = Image.new("RGB", (256, 256), color="white")
|
||||
draw = ImageDraw.Draw(img4)
|
||||
draw.polygon(
|
||||
[(128, 50), (50, 200), (200, 200)], fill="green", outline="black", width=2
|
||||
)
|
||||
images.append(("Зеленый треугольник", img4))
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def preprocess_image(image):
|
||||
"""Преобразование PIL Image в тензор."""
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
return transform(image).unsqueeze(0) # Добавляем batch dimension
|
||||
|
||||
|
||||
def display_results(images, similarities):
|
||||
"""Отображение результатов сравнения."""
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||||
axes = axes.flatten()
|
||||
|
||||
for idx, (title, img) in enumerate(images):
|
||||
ax = axes[idx]
|
||||
ax.imshow(img)
|
||||
ax.set_title(title, fontsize=12, fontweight="bold")
|
||||
ax.axis("off")
|
||||
|
||||
plt.suptitle("Тестовые изображения", fontsize=16, fontweight="bold")
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# Вывод результатов сравнения
|
||||
print("\n" + "=" * 60)
|
||||
print("РЕЗУЛЬТАТЫ СРАВНЕНИЯ ИЗОБРАЖЕНИЙ")
|
||||
print("=" * 60)
|
||||
|
||||
comparisons = [
|
||||
("Красный квадрат", "Похожий красный квадрат"),
|
||||
("Красный квадрат", "Синий круг"),
|
||||
("Красный квадрат", "Зеленый треугольник"),
|
||||
("Похожий красный квадрат", "Синий круг"),
|
||||
]
|
||||
|
||||
for i, (name1, name2) in enumerate(comparisons):
|
||||
idx1 = [idx for idx, (name, _) in enumerate(images) if name == name1][0]
|
||||
idx2 = [idx for idx, (name, _) in enumerate(images) if name == name2][0]
|
||||
|
||||
sim = similarities[idx1, idx2]
|
||||
interpretation = "ПОХОЖИ" if sim > 0.5 else "РАЗНЫЕ"
|
||||
|
||||
print(f"\n{name1} vs {name2}:")
|
||||
print(f" Схожесть: {sim:.4f}")
|
||||
print(f" Интерпретация: {interpretation}")
|
||||
print(f" Уверенность: {'Высокая' if sim > 0.7 or sim < 0.3 else 'Средняя'}")
|
||||
|
||||
|
||||
def test_loss_function():
|
||||
"""Тестирование функции потерь."""
|
||||
print("\n" + "=" * 60)
|
||||
print("ТЕСТИРОВАНИЕ ФУНКЦИИ ПОТЕРЬ")
|
||||
print("=" * 60)
|
||||
|
||||
loss_fn = SimilarityLoss()
|
||||
|
||||
# Тестовые данные
|
||||
predictions = torch.tensor([[0.9], [0.1], [0.7], [0.3]])
|
||||
targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
|
||||
|
||||
# Расчет потерь
|
||||
loss = loss_fn(predictions, targets)
|
||||
print(f"\nПотери: {loss.item():.4f}")
|
||||
|
||||
# Расчет метрик
|
||||
metrics = loss_fn.compute_metrics(predictions, targets)
|
||||
print("\nМетрики:")
|
||||
for key, value in metrics.items():
|
||||
print(f" {key}: {value:.4f}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Основная функция демонстрации."""
|
||||
print("ДЕМОНСТРАЦИЯ МОДЕЛИ ОЦЕНКИ СХОЖЕСТИ ИЗОБРАЖЕНИЙ")
|
||||
print("=" * 60)
|
||||
|
||||
# Создание модели
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"\nУстройство: {device}")
|
||||
|
||||
model = SimilarityCNN(
|
||||
input_channels=3,
|
||||
hidden_channels=64,
|
||||
num_blocks=4,
|
||||
dropout_rate=0.3,
|
||||
use_batch_norm=True,
|
||||
).to(device)
|
||||
|
||||
print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
# Создание тестовых изображений
|
||||
print("\nСоздание тестовых изображений...")
|
||||
test_images = create_test_images()
|
||||
|
||||
# Преобразование изображений в тензоры
|
||||
tensors = []
|
||||
for name, img in test_images:
|
||||
tensor = preprocess_image(img).to(device)
|
||||
tensors.append(tensor)
|
||||
|
||||
# Расчет схожести между всеми парами изображений
|
||||
print("\nРасчет схожести между изображениями...")
|
||||
n_images = len(test_images)
|
||||
similarity_matrix = np.zeros((n_images, n_images))
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for i in range(n_images):
|
||||
for j in range(n_images):
|
||||
if i <= j: # Рассчитываем только верхний треугольник
|
||||
sim = model.predict_similarity(tensors[i], tensors[j])
|
||||
similarity_matrix[i, j] = sim.item()
|
||||
similarity_matrix[j, i] = sim.item() # Симметричная матрица
|
||||
|
||||
# Отображение результатов
|
||||
display_results(test_images, similarity_matrix)
|
||||
|
||||
# Тестирование функции потерь
|
||||
test_loss_function()
|
||||
|
||||
# Дополнительная информация
|
||||
print("\n" + "=" * 60)
|
||||
print("ИНФОРМАЦИЯ О МОДЕЛИ")
|
||||
print("=" * 60)
|
||||
print("\nАрхитектура модели:")
|
||||
print("-" * 40)
|
||||
print("Вход: два изображения 256x256x3")
|
||||
print("Энкодер: CNN с residual блоками")
|
||||
print("Слой слияния: объединение признаков")
|
||||
print("Выход: значение схожести [0, 1]")
|
||||
print("\nИнтерпретация результатов:")
|
||||
print("- 0.8-1.0: Очень похожи")
|
||||
print("- 0.6-0.8: Похожи")
|
||||
print("- 0.4-0.6: Нейтрально")
|
||||
print("- 0.2-0.4: Разные")
|
||||
print("- 0.0-0.2: Совершенно разные")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("ДЕМОНСТРАЦИЯ ЗАВЕРШЕНА")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user