217 lines
7.7 KiB
Python
217 lines
7.7 KiB
Python
"""
|
||
Пример использования модели оценки схожести с даталоадером.
|
||
"""
|
||
|
||
import torch
|
||
from dataloader import YaGoDataset, create_data_loaders
|
||
from model import SimilarityCNN, SimilarityLoss
|
||
|
||
|
||
def main():
|
||
"""Основной пример использования."""
|
||
print("ПРИМЕР ИСПОЛЬЗОВАНИЯ МОДЕЛИ СХОЖЕСТИ С ДАТАЛОАДЕРОМ")
|
||
print("=" * 60)
|
||
|
||
# Конфигурация
|
||
config = {
|
||
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||
"batch_size": 4,
|
||
"image_size": (256, 256),
|
||
"train_split": 0.8,
|
||
"num_workers": 0,
|
||
}
|
||
|
||
# Устройство
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
print(f"Устройство: {device}")
|
||
|
||
# 1. Создание датасета
|
||
print("\n1. СОЗДАНИЕ ДАТАСЕТА")
|
||
print("-" * 40)
|
||
|
||
dataset = YaGoDataset(
|
||
root_dir=config["data_dir"],
|
||
augment=False,
|
||
image_size=config["image_size"],
|
||
)
|
||
|
||
print(f"Размер датасета: {len(dataset)} пар изображений")
|
||
|
||
# Получение примера из датасета
|
||
sample = dataset[0]
|
||
print(f"\nПример из датасета:")
|
||
print(f" Google image shape: {sample['google_img'].shape}")
|
||
print(f" Yandex image shape: {sample['yandex_img'].shape}")
|
||
print(f" Same domain: {sample['same_domain']}")
|
||
print(f" Index: {sample['idx'].item()}")
|
||
|
||
# 2. Создание даталоадеров
|
||
print("\n2. СОЗДАНИЕ ДАТАЛОАДЕРОВ")
|
||
print("-" * 40)
|
||
|
||
train_loader, val_loader = create_data_loaders(
|
||
root_dir=config["data_dir"],
|
||
batch_size=config["batch_size"],
|
||
train_split=config["train_split"],
|
||
num_workers=config["num_workers"],
|
||
image_size=config["image_size"],
|
||
augment_train=True,
|
||
augment_val=False,
|
||
device=device,
|
||
)
|
||
|
||
print(f"Train batches: {len(train_loader)}")
|
||
print(f"Val batches: {len(val_loader)}")
|
||
|
||
# 3. Создание модели
|
||
print("\n3. СОЗДАНИЕ МОДЕЛИ")
|
||
print("-" * 40)
|
||
|
||
model = SimilarityCNN(
|
||
input_channels=3,
|
||
hidden_channels=64,
|
||
num_blocks=4,
|
||
dropout_rate=0.3,
|
||
use_batch_norm=True,
|
||
).to(device)
|
||
|
||
print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")
|
||
|
||
# 4. Тестирование на одном батче
|
||
print("\n4. ТЕСТИРОВАНИЕ НА ОДНОМ БАТЧЕ")
|
||
print("-" * 40)
|
||
|
||
# Получаем батч из train_loader
|
||
for batch in train_loader:
|
||
google_img = batch["google_img"].to(device)
|
||
yandex_img = batch["yandex_img"].to(device)
|
||
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
|
||
|
||
print(f"Batch size: {google_img.shape[0]}")
|
||
print(f"Image shape: {google_img.shape[1:]}")
|
||
print(f"Same domain labels: {same_domain.squeeze().tolist()}")
|
||
|
||
# Предсказание схожести
|
||
with torch.no_grad():
|
||
predictions = model.predict_similarity(google_img, yandex_img)
|
||
print(f"\nПредсказания схожести:")
|
||
for i in range(len(predictions)):
|
||
print(
|
||
f" Sample {i}: {predictions[i].item():.4f} (target: {same_domain[i].item():.1f})"
|
||
)
|
||
|
||
# Расчет потерь
|
||
loss_fn = SimilarityLoss().to(device)
|
||
loss = loss_fn(predictions, same_domain)
|
||
print(f"\nПотери на батче: {loss.item():.4f}")
|
||
|
||
# Расчет метрик
|
||
metrics = loss_fn.compute_metrics(predictions, same_domain)
|
||
print("\nМетрики на батче:")
|
||
for key, value in metrics.items():
|
||
print(f" {key}: {value:.4f}")
|
||
|
||
break # Только первый батч
|
||
|
||
# 5. Обучение на одном эпохе (демонстрация)
|
||
print("\n5. ДЕМОНСТРАЦИЯ ОБУЧЕНИЯ НА ОДНОЙ ЭПОХЕ")
|
||
print("-" * 40)
|
||
|
||
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
|
||
model.train()
|
||
|
||
total_loss = 0
|
||
total_samples = 0
|
||
|
||
for batch_idx, batch in enumerate(train_loader):
|
||
if batch_idx >= 3: # Ограничиваем 3 батчами для демонстрации
|
||
break
|
||
|
||
google_img = batch["google_img"].to(device)
|
||
yandex_img = batch["yandex_img"].to(device)
|
||
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
|
||
|
||
optimizer.zero_grad()
|
||
|
||
predictions = model(google_img, yandex_img)
|
||
loss = loss_fn(predictions, same_domain)
|
||
|
||
loss.backward()
|
||
optimizer.step()
|
||
|
||
batch_loss = loss.item() * google_img.size(0)
|
||
total_loss += batch_loss
|
||
total_samples += google_img.size(0)
|
||
|
||
print(f"Batch {batch_idx + 1}: loss = {loss.item():.4f}")
|
||
|
||
avg_loss = total_loss / total_samples
|
||
print(f"\nСредние потери за 3 батча: {avg_loss:.4f}")
|
||
|
||
# 6. Валидация
|
||
print("\n6. ВАЛИДАЦИЯ")
|
||
print("-" * 40)
|
||
|
||
model.eval()
|
||
val_loss = 0
|
||
val_samples = 0
|
||
|
||
with torch.no_grad():
|
||
for batch_idx, batch in enumerate(val_loader):
|
||
if batch_idx >= 2: # Ограничиваем 2 батчами для демонстрации
|
||
break
|
||
|
||
google_img = batch["google_img"].to(device)
|
||
yandex_img = batch["yandex_img"].to(device)
|
||
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
|
||
|
||
predictions = model.predict_similarity(google_img, yandex_img)
|
||
loss = loss_fn(predictions, same_domain)
|
||
|
||
val_loss += loss.item() * google_img.size(0)
|
||
val_samples += google_img.size(0)
|
||
|
||
print(f"Val batch {batch_idx + 1}: loss = {loss.item():.4f}")
|
||
|
||
avg_val_loss = val_loss / val_samples
|
||
print(f"\nСредние потери на валидации: {avg_val_loss:.4f}")
|
||
|
||
# 7. Пример использования для отдельных изображений
|
||
print("\n7. ПРИМЕР ДЛЯ ОТДЕЛЬНЫХ ИЗОБРАЖЕНИЙ")
|
||
print("-" * 40)
|
||
|
||
# Берем два примера из датасета
|
||
sample1 = dataset[0]
|
||
sample2 = dataset[1]
|
||
|
||
# Подготавливаем тензоры
|
||
img1_1 = sample1["google_img"].unsqueeze(0).to(device)
|
||
img1_2 = sample1["yandex_img"].unsqueeze(0).to(device)
|
||
|
||
img2_1 = sample2["google_img"].unsqueeze(0).to(device)
|
||
img2_2 = sample2["yandex_img"].unsqueeze(0).to(device)
|
||
|
||
# Предсказания
|
||
with torch.no_grad():
|
||
# Сравнение пар из одного домена
|
||
sim_same1 = model.predict_similarity(img1_1, img1_2)
|
||
sim_same2 = model.predict_similarity(img2_1, img2_2)
|
||
|
||
# Сравнение пар из разных доменов
|
||
sim_diff1 = model.predict_similarity(img1_1, img2_2)
|
||
sim_diff2 = model.predict_similarity(img2_1, img1_2)
|
||
|
||
print("Сравнение пар изображений:")
|
||
print(f" Пара 1 (один домен): {sim_same1.item():.4f}")
|
||
print(f" Пара 2 (один домен): {sim_same2.item():.4f}")
|
||
print(f" Разные домены 1: {sim_diff1.item():.4f}")
|
||
print(f" Разные домены 2: {sim_diff2.item():.4f}")
|
||
|
||
print("\n" + "=" * 60)
|
||
print("ПРИМЕР ЗАВЕРШЕН")
|
||
print("=" * 60)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|