Files
autopilot/models/SiaN-similarity/demo.py

193 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Демонстрационный скрипт для модели оценки схожести изображений.
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
from model import SimilarityCNN, SimilarityLoss
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
def create_test_images():
"""Создание тестовых изображений для демонстрации."""
images = []
# Изображение 1: Красный квадрат
img1 = Image.new("RGB", (256, 256), color="white")
draw = ImageDraw.Draw(img1)
draw.rectangle([50, 50, 200, 200], fill="red", outline="black", width=2)
images.append(("Красный квадрат", img1))
# Изображение 2: Тот же красный квадрат (похожее)
img2 = Image.new("RGB", (256, 256), color="white")
draw = ImageDraw.Draw(img2)
draw.rectangle([55, 55, 205, 205], fill="red", outline="black", width=2)
images.append(("Похожий красный квадрат", img2))
# Изображение 3: Синий круг (разное)
img3 = Image.new("RGB", (256, 256), color="white")
draw = ImageDraw.Draw(img3)
draw.ellipse([50, 50, 200, 200], fill="blue", outline="black", width=2)
images.append(("Синий круг", img3))
# Изображение 4: Зеленый треугольник (разное)
img4 = Image.new("RGB", (256, 256), color="white")
draw = ImageDraw.Draw(img4)
draw.polygon(
[(128, 50), (50, 200), (200, 200)], fill="green", outline="black", width=2
)
images.append(("Зеленый треугольник", img4))
return images
def preprocess_image(image):
"""Преобразование PIL Image в тензор."""
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
return transform(image).unsqueeze(0) # Добавляем batch dimension
def display_results(images, similarities):
"""Отображение результатов сравнения."""
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()
for idx, (title, img) in enumerate(images):
ax = axes[idx]
ax.imshow(img)
ax.set_title(title, fontsize=12, fontweight="bold")
ax.axis("off")
plt.suptitle("Тестовые изображения", fontsize=16, fontweight="bold")
plt.tight_layout()
plt.show()
# Вывод результатов сравнения
print("\n" + "=" * 60)
print("РЕЗУЛЬТАТЫ СРАВНЕНИЯ ИЗОБРАЖЕНИЙ")
print("=" * 60)
comparisons = [
("Красный квадрат", "Похожий красный квадрат"),
("Красный квадрат", "Синий круг"),
("Красный квадрат", "Зеленый треугольник"),
("Похожий красный квадрат", "Синий круг"),
]
for i, (name1, name2) in enumerate(comparisons):
idx1 = [idx for idx, (name, _) in enumerate(images) if name == name1][0]
idx2 = [idx for idx, (name, _) in enumerate(images) if name == name2][0]
sim = similarities[idx1, idx2]
interpretation = "ПОХОЖИ" if sim > 0.5 else "РАЗНЫЕ"
print(f"\n{name1} vs {name2}:")
print(f" Схожесть: {sim:.4f}")
print(f" Интерпретация: {interpretation}")
print(f" Уверенность: {'Высокая' if sim > 0.7 or sim < 0.3 else 'Средняя'}")
def test_loss_function():
"""Тестирование функции потерь."""
print("\n" + "=" * 60)
print("ТЕСТИРОВАНИЕ ФУНКЦИИ ПОТЕРЬ")
print("=" * 60)
loss_fn = SimilarityLoss()
# Тестовые данные
predictions = torch.tensor([[0.9], [0.1], [0.7], [0.3]])
targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
# Расчет потерь
loss = loss_fn(predictions, targets)
print(f"\nПотери: {loss.item():.4f}")
# Расчет метрик
metrics = loss_fn.compute_metrics(predictions, targets)
print("\nМетрики:")
for key, value in metrics.items():
print(f" {key}: {value:.4f}")
def main():
"""Основная функция демонстрации."""
print("ДЕМОНСТРАЦИЯ МОДЕЛИ ОЦЕНКИ СХОЖЕСТИ ИЗОБРАЖЕНИЙ")
print("=" * 60)
# Создание модели
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nУстройство: {device}")
model = SimilarityCNN(
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
).to(device)
print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")
# Создание тестовых изображений
print("\nСоздание тестовых изображений...")
test_images = create_test_images()
# Преобразование изображений в тензоры
tensors = []
for name, img in test_images:
tensor = preprocess_image(img).to(device)
tensors.append(tensor)
# Расчет схожести между всеми парами изображений
print("\nРасчет схожести между изображениями...")
n_images = len(test_images)
similarity_matrix = np.zeros((n_images, n_images))
model.eval()
with torch.no_grad():
for i in range(n_images):
for j in range(n_images):
if i <= j: # Рассчитываем только верхний треугольник
sim = model.predict_similarity(tensors[i], tensors[j])
similarity_matrix[i, j] = sim.item()
similarity_matrix[j, i] = sim.item() # Симметричная матрица
# Отображение результатов
display_results(test_images, similarity_matrix)
# Тестирование функции потерь
test_loss_function()
# Дополнительная информация
print("\n" + "=" * 60)
print("ИНФОРМАЦИЯ О МОДЕЛИ")
print("=" * 60)
print("\nАрхитектура модели:")
print("-" * 40)
print("Вход: два изображения 256x256x3")
print("Энкодер: CNN с residual блоками")
print("Слой слияния: объединение признаков")
print("Выход: значение схожести [0, 1]")
print("\nИнтерпретация результатов:")
print("- 0.8-1.0: Очень похожи")
print("- 0.6-0.8: Похожи")
print("- 0.4-0.6: Нейтрально")
print("- 0.2-0.4: Разные")
print("- 0.0-0.2: Совершенно разные")
print("\n" + "=" * 60)
print("ДЕМОНСТРАЦИЯ ЗАВЕРШЕНА")
print("=" * 60)
if __name__ == "__main__":
main()