Files
autopilot/models/SiaN-similarity/example.py

217 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Пример использования модели оценки схожести с даталоадером.
"""
import torch
from dataloader import YaGoDataset, create_data_loaders
from model import SimilarityCNN, SimilarityLoss
def main():
"""Основной пример использования."""
print("ПРИМЕР ИСПОЛЬЗОВАНИЯ МОДЕЛИ СХОЖЕСТИ С ДАТАЛОАДЕРОМ")
print("=" * 60)
# Конфигурация
config = {
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
"batch_size": 4,
"image_size": (256, 256),
"train_split": 0.8,
"num_workers": 0,
}
# Устройство
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {device}")
# 1. Создание датасета
print("\n1. СОЗДАНИЕ ДАТАСЕТА")
print("-" * 40)
dataset = YaGoDataset(
root_dir=config["data_dir"],
augment=False,
image_size=config["image_size"],
)
print(f"Размер датасета: {len(dataset)} пар изображений")
# Получение примера из датасета
sample = dataset[0]
print(f"\nПример из датасета:")
print(f" Google image shape: {sample['google_img'].shape}")
print(f" Yandex image shape: {sample['yandex_img'].shape}")
print(f" Same domain: {sample['same_domain']}")
print(f" Index: {sample['idx'].item()}")
# 2. Создание даталоадеров
print("\n2. СОЗДАНИЕ ДАТАЛОАДЕРОВ")
print("-" * 40)
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=config["batch_size"],
train_split=config["train_split"],
num_workers=config["num_workers"],
image_size=config["image_size"],
augment_train=True,
augment_val=False,
device=device,
)
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
# 3. Создание модели
print("\n3. СОЗДАНИЕ МОДЕЛИ")
print("-" * 40)
model = SimilarityCNN(
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
).to(device)
print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")
# 4. Тестирование на одном батче
print("\n4. ТЕСТИРОВАНИЕ НА ОДНОМ БАТЧЕ")
print("-" * 40)
# Получаем батч из train_loader
for batch in train_loader:
google_img = batch["google_img"].to(device)
yandex_img = batch["yandex_img"].to(device)
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
print(f"Batch size: {google_img.shape[0]}")
print(f"Image shape: {google_img.shape[1:]}")
print(f"Same domain labels: {same_domain.squeeze().tolist()}")
# Предсказание схожести
with torch.no_grad():
predictions = model.predict_similarity(google_img, yandex_img)
print(f"\nПредсказания схожести:")
for i in range(len(predictions)):
print(
f" Sample {i}: {predictions[i].item():.4f} (target: {same_domain[i].item():.1f})"
)
# Расчет потерь
loss_fn = SimilarityLoss().to(device)
loss = loss_fn(predictions, same_domain)
print(f"\nПотери на батче: {loss.item():.4f}")
# Расчет метрик
metrics = loss_fn.compute_metrics(predictions, same_domain)
print("\nМетрики на батче:")
for key, value in metrics.items():
print(f" {key}: {value:.4f}")
break # Только первый батч
# 5. Обучение на одном эпохе (демонстрация)
print("\n5. ДЕМОНСТРАЦИЯ ОБУЧЕНИЯ НА ОДНОЙ ЭПОХЕ")
print("-" * 40)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
model.train()
total_loss = 0
total_samples = 0
for batch_idx, batch in enumerate(train_loader):
if batch_idx >= 3: # Ограничиваем 3 батчами для демонстрации
break
google_img = batch["google_img"].to(device)
yandex_img = batch["yandex_img"].to(device)
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
optimizer.zero_grad()
predictions = model(google_img, yandex_img)
loss = loss_fn(predictions, same_domain)
loss.backward()
optimizer.step()
batch_loss = loss.item() * google_img.size(0)
total_loss += batch_loss
total_samples += google_img.size(0)
print(f"Batch {batch_idx + 1}: loss = {loss.item():.4f}")
avg_loss = total_loss / total_samples
print(f"\nСредние потери за 3 батча: {avg_loss:.4f}")
# 6. Валидация
print("\n6. ВАЛИДАЦИЯ")
print("-" * 40)
model.eval()
val_loss = 0
val_samples = 0
with torch.no_grad():
for batch_idx, batch in enumerate(val_loader):
if batch_idx >= 2: # Ограничиваем 2 батчами для демонстрации
break
google_img = batch["google_img"].to(device)
yandex_img = batch["yandex_img"].to(device)
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
predictions = model.predict_similarity(google_img, yandex_img)
loss = loss_fn(predictions, same_domain)
val_loss += loss.item() * google_img.size(0)
val_samples += google_img.size(0)
print(f"Val batch {batch_idx + 1}: loss = {loss.item():.4f}")
avg_val_loss = val_loss / val_samples
print(f"\nСредние потери на валидации: {avg_val_loss:.4f}")
# 7. Пример использования для отдельных изображений
print("\n7. ПРИМЕР ДЛЯ ОТДЕЛЬНЫХ ИЗОБРАЖЕНИЙ")
print("-" * 40)
# Берем два примера из датасета
sample1 = dataset[0]
sample2 = dataset[1]
# Подготавливаем тензоры
img1_1 = sample1["google_img"].unsqueeze(0).to(device)
img1_2 = sample1["yandex_img"].unsqueeze(0).to(device)
img2_1 = sample2["google_img"].unsqueeze(0).to(device)
img2_2 = sample2["yandex_img"].unsqueeze(0).to(device)
# Предсказания
with torch.no_grad():
# Сравнение пар из одного домена
sim_same1 = model.predict_similarity(img1_1, img1_2)
sim_same2 = model.predict_similarity(img2_1, img2_2)
# Сравнение пар из разных доменов
sim_diff1 = model.predict_similarity(img1_1, img2_2)
sim_diff2 = model.predict_similarity(img2_1, img1_2)
print("Сравнение пар изображений:")
print(f" Пара 1 (один домен): {sim_same1.item():.4f}")
print(f" Пара 2 (один домен): {sim_same2.item():.4f}")
print(f" Разные домены 1: {sim_diff1.item():.4f}")
print(f" Разные домены 2: {sim_diff2.item():.4f}")
print("\n" + "=" * 60)
print("ПРИМЕР ЗАВЕРШЕН")
print("=" * 60)
if __name__ == "__main__":
main()