""" Пример использования модели оценки схожести с даталоадером. """ import torch from dataloader import YaGoDataset, create_data_loaders from model import SimilarityCNN, SimilarityLoss def main(): """Основной пример использования.""" print("ПРИМЕР ИСПОЛЬЗОВАНИЯ МОДЕЛИ СХОЖЕСТИ С ДАТАЛОАДЕРОМ") print("=" * 60) # Конфигурация config = { "data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", "batch_size": 4, "image_size": (256, 256), "train_split": 0.8, "num_workers": 0, } # Устройство device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Устройство: {device}") # 1. Создание датасета print("\n1. СОЗДАНИЕ ДАТАСЕТА") print("-" * 40) dataset = YaGoDataset( root_dir=config["data_dir"], augment=False, image_size=config["image_size"], ) print(f"Размер датасета: {len(dataset)} пар изображений") # Получение примера из датасета sample = dataset[0] print(f"\nПример из датасета:") print(f" Google image shape: {sample['google_img'].shape}") print(f" Yandex image shape: {sample['yandex_img'].shape}") print(f" Same domain: {sample['same_domain']}") print(f" Index: {sample['idx'].item()}") # 2. Создание даталоадеров print("\n2. СОЗДАНИЕ ДАТАЛОАДЕРОВ") print("-" * 40) train_loader, val_loader = create_data_loaders( root_dir=config["data_dir"], batch_size=config["batch_size"], train_split=config["train_split"], num_workers=config["num_workers"], image_size=config["image_size"], augment_train=True, augment_val=False, device=device, ) print(f"Train batches: {len(train_loader)}") print(f"Val batches: {len(val_loader)}") # 3. Создание модели print("\n3. СОЗДАНИЕ МОДЕЛИ") print("-" * 40) model = SimilarityCNN( input_channels=3, hidden_channels=64, num_blocks=4, dropout_rate=0.3, use_batch_norm=True, ).to(device) print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}") # 4. Тестирование на одном батче print("\n4. ТЕСТИРОВАНИЕ НА ОДНОМ БАТЧЕ") print("-" * 40) # Получаем батч из train_loader for batch in train_loader: google_img = batch["google_img"].to(device) yandex_img = batch["yandex_img"].to(device) same_domain = batch["same_domain"].float().to(device).unsqueeze(1) print(f"Batch size: {google_img.shape[0]}") print(f"Image shape: {google_img.shape[1:]}") print(f"Same domain labels: {same_domain.squeeze().tolist()}") # Предсказание схожести with torch.no_grad(): predictions = model.predict_similarity(google_img, yandex_img) print(f"\nПредсказания схожести:") for i in range(len(predictions)): print( f" Sample {i}: {predictions[i].item():.4f} (target: {same_domain[i].item():.1f})" ) # Расчет потерь loss_fn = SimilarityLoss().to(device) loss = loss_fn(predictions, same_domain) print(f"\nПотери на батче: {loss.item():.4f}") # Расчет метрик metrics = loss_fn.compute_metrics(predictions, same_domain) print("\nМетрики на батче:") for key, value in metrics.items(): print(f" {key}: {value:.4f}") break # Только первый батч # 5. Обучение на одном эпохе (демонстрация) print("\n5. ДЕМОНСТРАЦИЯ ОБУЧЕНИЯ НА ОДНОЙ ЭПОХЕ") print("-" * 40) optimizer = torch.optim.Adam(model.parameters(), lr=2e-4) model.train() total_loss = 0 total_samples = 0 for batch_idx, batch in enumerate(train_loader): if batch_idx >= 3: # Ограничиваем 3 батчами для демонстрации break google_img = batch["google_img"].to(device) yandex_img = batch["yandex_img"].to(device) same_domain = batch["same_domain"].float().to(device).unsqueeze(1) optimizer.zero_grad() predictions = model(google_img, yandex_img) loss = loss_fn(predictions, same_domain) loss.backward() optimizer.step() batch_loss = loss.item() * google_img.size(0) total_loss += batch_loss total_samples += google_img.size(0) print(f"Batch {batch_idx + 1}: loss = {loss.item():.4f}") avg_loss = total_loss / total_samples print(f"\nСредние потери за 3 батча: {avg_loss:.4f}") # 6. Валидация print("\n6. ВАЛИДАЦИЯ") print("-" * 40) model.eval() val_loss = 0 val_samples = 0 with torch.no_grad(): for batch_idx, batch in enumerate(val_loader): if batch_idx >= 2: # Ограничиваем 2 батчами для демонстрации break google_img = batch["google_img"].to(device) yandex_img = batch["yandex_img"].to(device) same_domain = batch["same_domain"].float().to(device).unsqueeze(1) predictions = model.predict_similarity(google_img, yandex_img) loss = loss_fn(predictions, same_domain) val_loss += loss.item() * google_img.size(0) val_samples += google_img.size(0) print(f"Val batch {batch_idx + 1}: loss = {loss.item():.4f}") avg_val_loss = val_loss / val_samples print(f"\nСредние потери на валидации: {avg_val_loss:.4f}") # 7. Пример использования для отдельных изображений print("\n7. ПРИМЕР ДЛЯ ОТДЕЛЬНЫХ ИЗОБРАЖЕНИЙ") print("-" * 40) # Берем два примера из датасета sample1 = dataset[0] sample2 = dataset[1] # Подготавливаем тензоры img1_1 = sample1["google_img"].unsqueeze(0).to(device) img1_2 = sample1["yandex_img"].unsqueeze(0).to(device) img2_1 = sample2["google_img"].unsqueeze(0).to(device) img2_2 = sample2["yandex_img"].unsqueeze(0).to(device) # Предсказания with torch.no_grad(): # Сравнение пар из одного домена sim_same1 = model.predict_similarity(img1_1, img1_2) sim_same2 = model.predict_similarity(img2_1, img2_2) # Сравнение пар из разных доменов sim_diff1 = model.predict_similarity(img1_1, img2_2) sim_diff2 = model.predict_similarity(img2_1, img1_2) print("Сравнение пар изображений:") print(f" Пара 1 (один домен): {sim_same1.item():.4f}") print(f" Пара 2 (один домен): {sim_same2.item():.4f}") print(f" Разные домены 1: {sim_diff1.item():.4f}") print(f" Разные домены 2: {sim_diff2.item():.4f}") print("\n" + "=" * 60) print("ПРИМЕР ЗАВЕРШЕН") print("=" * 60) if __name__ == "__main__": main()