""" Минимальный пример использования GAN trainer для преобразования Yandex → Google карт. Этот пример показывает самый простой способ использования тренера. """ import sys from pathlib import Path import torch from torch.utils.data import DataLoader, Dataset # Добавляем путь к модулям sys.path.append(str(Path(__file__).parent.parent.parent)) from models.GAN.gan import create_image_gan from models.GAN.trainer import GANTrainer class SimpleMapDataset(Dataset): """Простой датасет с фиктивными данными для примера.""" def __init__(self, num_samples=100, image_size=(256, 256)): self.num_samples = num_samples self.image_size = image_size def __len__(self): return self.num_samples def __getitem__(self, idx): # Создаем фиктивные изображения # В реальном коде замените на загрузку реальных изображений yandex_img = torch.randn(3, self.image_size[0], self.image_size[1]) google_img = torch.randn(3, self.image_size[0], self.image_size[1]) return {"yandex_img": yandex_img, "google_img": google_img} def main(): """Основная функция минимального примера.""" print("Минимальный пример использования GAN trainer") print("=" * 50) # 1. Конфигурация (минимальный набор параметров) config = { "learning_rate": 2e-4, "batch_size": 4, "output_dir": "runs/gan_minimal", } # 2. Устройство (CPU или GPU) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Используемое устройство: {device}") # 3. Создание модели print("\nСоздание GAN модели...") model = create_image_gan( input_channels=3, output_channels=3, gan_mode="vanilla", # Простейший режим lambda_L1=100.0, # Стандартный вес L1 потерь use_cuda=(device.type == "cuda"), ) # 4. Создание даталоадеров print("Создание даталоадеров...") train_dataset = SimpleMapDataset(num_samples=50) val_dataset = SimpleMapDataset(num_samples=10) train_loader = DataLoader( train_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=0, ) val_loader = DataLoader( val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=0, ) print(f" Обучающих примеров: {len(train_dataset)}") print(f" Валидационных примеров: {len(val_dataset)}") # 5. Создание тренера print("\nСоздание тренера...") trainer = GANTrainer( model=model, train_loader=train_loader, val_loader=val_loader, device=device, config=config, ) # 6. Обучение на небольшом количестве эпох print("\nЗапуск обучения (3 эпохи для примера)...") print("=" * 50) trainer.train(num_epochs=3) # 7. Генерация примеров print("\nГенерация примеров преобразования...") model.eval() # Создаем тестовые данные test_yandex = torch.randn(2, 3, 256, 256).to(device) with torch.no_grad(): generated_google = model(test_yandex) print(f"Входные изображения: {test_yandex.shape}") print(f"Сгенерированные изображения: {generated_google.shape}") print( f"Диапазон значений: [{generated_google.min():.3f}, {generated_google.max():.3f}]" ) # 8. Сохранение финальной модели print("\nСохранение модели...") model_save_path = "gan_model_minimal.pth" torch.save(model.state_dict(), model_save_path) print(f"Модель сохранена в: {model_save_path}") print("\n" + "=" * 50) print("Минимальный пример завершен!") print("\nДля реального использования:") print("1. Замените SimpleMapDataset на ваш реальный датасет") print("2. Настройте параметры в config") print("3. Увеличьте количество эпох (например, до 100)") print("4. Используйте реальные изображения карт") print("=" * 50) if __name__ == "__main__": main()