193 lines
7.3 KiB
Python
193 lines
7.3 KiB
Python
"""
|
||
Демонстрационный скрипт для модели оценки схожести изображений.
|
||
"""
|
||
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import torch
|
||
from model import SimilarityCNN, SimilarityLoss
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
from torchvision import transforms
|
||
|
||
|
||
def create_test_images():
|
||
"""Создание тестовых изображений для демонстрации."""
|
||
images = []
|
||
|
||
# Изображение 1: Красный квадрат
|
||
img1 = Image.new("RGB", (256, 256), color="white")
|
||
draw = ImageDraw.Draw(img1)
|
||
draw.rectangle([50, 50, 200, 200], fill="red", outline="black", width=2)
|
||
images.append(("Красный квадрат", img1))
|
||
|
||
# Изображение 2: Тот же красный квадрат (похожее)
|
||
img2 = Image.new("RGB", (256, 256), color="white")
|
||
draw = ImageDraw.Draw(img2)
|
||
draw.rectangle([55, 55, 205, 205], fill="red", outline="black", width=2)
|
||
images.append(("Похожий красный квадрат", img2))
|
||
|
||
# Изображение 3: Синий круг (разное)
|
||
img3 = Image.new("RGB", (256, 256), color="white")
|
||
draw = ImageDraw.Draw(img3)
|
||
draw.ellipse([50, 50, 200, 200], fill="blue", outline="black", width=2)
|
||
images.append(("Синий круг", img3))
|
||
|
||
# Изображение 4: Зеленый треугольник (разное)
|
||
img4 = Image.new("RGB", (256, 256), color="white")
|
||
draw = ImageDraw.Draw(img4)
|
||
draw.polygon(
|
||
[(128, 50), (50, 200), (200, 200)], fill="green", outline="black", width=2
|
||
)
|
||
images.append(("Зеленый треугольник", img4))
|
||
|
||
return images
|
||
|
||
|
||
def preprocess_image(image):
|
||
"""Преобразование PIL Image в тензор."""
|
||
transform = transforms.Compose(
|
||
[
|
||
transforms.ToTensor(),
|
||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||
]
|
||
)
|
||
return transform(image).unsqueeze(0) # Добавляем batch dimension
|
||
|
||
|
||
def display_results(images, similarities):
|
||
"""Отображение результатов сравнения."""
|
||
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||
axes = axes.flatten()
|
||
|
||
for idx, (title, img) in enumerate(images):
|
||
ax = axes[idx]
|
||
ax.imshow(img)
|
||
ax.set_title(title, fontsize=12, fontweight="bold")
|
||
ax.axis("off")
|
||
|
||
plt.suptitle("Тестовые изображения", fontsize=16, fontweight="bold")
|
||
plt.tight_layout()
|
||
plt.show()
|
||
|
||
# Вывод результатов сравнения
|
||
print("\n" + "=" * 60)
|
||
print("РЕЗУЛЬТАТЫ СРАВНЕНИЯ ИЗОБРАЖЕНИЙ")
|
||
print("=" * 60)
|
||
|
||
comparisons = [
|
||
("Красный квадрат", "Похожий красный квадрат"),
|
||
("Красный квадрат", "Синий круг"),
|
||
("Красный квадрат", "Зеленый треугольник"),
|
||
("Похожий красный квадрат", "Синий круг"),
|
||
]
|
||
|
||
for i, (name1, name2) in enumerate(comparisons):
|
||
idx1 = [idx for idx, (name, _) in enumerate(images) if name == name1][0]
|
||
idx2 = [idx for idx, (name, _) in enumerate(images) if name == name2][0]
|
||
|
||
sim = similarities[idx1, idx2]
|
||
interpretation = "ПОХОЖИ" if sim > 0.5 else "РАЗНЫЕ"
|
||
|
||
print(f"\n{name1} vs {name2}:")
|
||
print(f" Схожесть: {sim:.4f}")
|
||
print(f" Интерпретация: {interpretation}")
|
||
print(f" Уверенность: {'Высокая' if sim > 0.7 or sim < 0.3 else 'Средняя'}")
|
||
|
||
|
||
def test_loss_function():
|
||
"""Тестирование функции потерь."""
|
||
print("\n" + "=" * 60)
|
||
print("ТЕСТИРОВАНИЕ ФУНКЦИИ ПОТЕРЬ")
|
||
print("=" * 60)
|
||
|
||
loss_fn = SimilarityLoss()
|
||
|
||
# Тестовые данные
|
||
predictions = torch.tensor([[0.9], [0.1], [0.7], [0.3]])
|
||
targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
|
||
|
||
# Расчет потерь
|
||
loss = loss_fn(predictions, targets)
|
||
print(f"\nПотери: {loss.item():.4f}")
|
||
|
||
# Расчет метрик
|
||
metrics = loss_fn.compute_metrics(predictions, targets)
|
||
print("\nМетрики:")
|
||
for key, value in metrics.items():
|
||
print(f" {key}: {value:.4f}")
|
||
|
||
|
||
def main():
|
||
"""Основная функция демонстрации."""
|
||
print("ДЕМОНСТРАЦИЯ МОДЕЛИ ОЦЕНКИ СХОЖЕСТИ ИЗОБРАЖЕНИЙ")
|
||
print("=" * 60)
|
||
|
||
# Создание модели
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
print(f"\nУстройство: {device}")
|
||
|
||
model = SimilarityCNN(
|
||
input_channels=3,
|
||
hidden_channels=64,
|
||
num_blocks=4,
|
||
dropout_rate=0.3,
|
||
use_batch_norm=True,
|
||
).to(device)
|
||
|
||
print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")
|
||
|
||
# Создание тестовых изображений
|
||
print("\nСоздание тестовых изображений...")
|
||
test_images = create_test_images()
|
||
|
||
# Преобразование изображений в тензоры
|
||
tensors = []
|
||
for name, img in test_images:
|
||
tensor = preprocess_image(img).to(device)
|
||
tensors.append(tensor)
|
||
|
||
# Расчет схожести между всеми парами изображений
|
||
print("\nРасчет схожести между изображениями...")
|
||
n_images = len(test_images)
|
||
similarity_matrix = np.zeros((n_images, n_images))
|
||
|
||
model.eval()
|
||
with torch.no_grad():
|
||
for i in range(n_images):
|
||
for j in range(n_images):
|
||
if i <= j: # Рассчитываем только верхний треугольник
|
||
sim = model.predict_similarity(tensors[i], tensors[j])
|
||
similarity_matrix[i, j] = sim.item()
|
||
similarity_matrix[j, i] = sim.item() # Симметричная матрица
|
||
|
||
# Отображение результатов
|
||
display_results(test_images, similarity_matrix)
|
||
|
||
# Тестирование функции потерь
|
||
test_loss_function()
|
||
|
||
# Дополнительная информация
|
||
print("\n" + "=" * 60)
|
||
print("ИНФОРМАЦИЯ О МОДЕЛИ")
|
||
print("=" * 60)
|
||
print("\nАрхитектура модели:")
|
||
print("-" * 40)
|
||
print("Вход: два изображения 256x256x3")
|
||
print("Энкодер: CNN с residual блоками")
|
||
print("Слой слияния: объединение признаков")
|
||
print("Выход: значение схожести [0, 1]")
|
||
print("\nИнтерпретация результатов:")
|
||
print("- 0.8-1.0: Очень похожи")
|
||
print("- 0.6-0.8: Похожи")
|
||
print("- 0.4-0.6: Нейтрально")
|
||
print("- 0.2-0.4: Разные")
|
||
print("- 0.0-0.2: Совершенно разные")
|
||
|
||
print("\n" + "=" * 60)
|
||
print("ДЕМОНСТРАЦИЯ ЗАВЕРШЕНА")
|
||
print("=" * 60)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|