""" Пример обучения GAN модели для преобразования Yandex → Google карт. Этот скрипт показывает, как использовать GANTrainer для обучения модели. """ import sys from pathlib import Path import torch from torch.utils.data import DataLoader # Добавляем путь к модулям sys.path.append(str(Path(__file__).parent.parent.parent)) from models.GAN.gan import create_image_gan from models.GAN.trainer import GANTrainer def create_simple_config(): """Создает простую конфигурацию для обучения.""" config = { # Параметры оптимизатора "learning_rate": 2e-4, "beta1": 0.5, "beta2": 0.999, # Параметры обучения "batch_size": 4, "epochs": 100, # Параметры GAN "gan_mode": "vanilla", # "vanilla", "lsgan", или "wgangp" "lambda_L1": 100.0, # Вес L1 потерь # Регуляризация "grad_clip": 1.0, # Ранняя остановка "early_stopping_patience": 20, # Выходные данные "output_dir": "runs/gan_training", # Логирование "log_interval": 10, # Логировать каждые N батчей "save_interval": 5, # Сохранять чекпоинт каждые N эпох } return config def create_advanced_config(): """Создает расширенную конфигурацию для обучения.""" config = { # Параметры оптимизатора "learning_rate": 2e-4, "beta1": 0.5, "beta2": 0.999, # Планировщик learning rate "use_scheduler": True, "scheduler_type": "linear", # "linear", "cosine", или "plateau" "scheduler_start_epoch": 50, "scheduler_end_epoch": 100, # Параметры обучения "batch_size": 8, "epochs": 200, # Параметры GAN "gan_mode": "lsgan", # LSGAN обычно более стабилен "lambda_L1": 100.0, # Аугментация данных "augmentation": { "random_crop": True, "crop_size": 256, "random_flip": True, "color_jitter": True, "brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.1, }, # Регуляризация "grad_clip": 1.0, "weight_decay": 1e-4, # Ранняя остановка "early_stopping_patience": 30, "early_stopping_min_delta": 1e-4, # Выходные данные "output_dir": "runs/gan_advanced", # Логирование "log_interval": 20, "save_interval": 10, "save_best_only": True, # Сохранять только лучшую модель # Визуализация "visualize_samples": True, "num_visualize": 4, "visualize_interval": 5, # Визуализировать каждые N эпох } return config def print_config_summary(config): """Печатает сводку конфигурации.""" print("=" * 60) print("Конфигурация обучения GAN") print("=" * 60) print(f"\nПараметры модели:") print(f" Режим GAN: {config.get('gan_mode', 'vanilla')}") print(f" Вес L1 потерь: {config.get('lambda_L1', 100.0)}") print(f"\nПараметры обучения:") print(f" Learning rate: {config.get('learning_rate', 2e-4)}") print(f" Batch size: {config.get('batch_size', 4)}") print(f" Эпох: {config.get('epochs', 100)}") print(f" Beta1: {config.get('beta1', 0.5)}") print(f" Beta2: {config.get('beta2', 0.999)}") if config.get("use_scheduler", False): print(f" Планировщик LR: {config.get('scheduler_type', 'linear')}") print(f"\nРегуляризация:") print(f" Gradient clipping: {config.get('grad_clip', 1.0)}") if "weight_decay" in config: print(f" Weight decay: {config['weight_decay']}") print(f"\nРанняя остановка:") if config.get("early_stopping_patience", 0) > 0: print(f" Patience: {config['early_stopping_patience']} эпох") if "early_stopping_min_delta" in config: print(f" Min delta: {config['early_stopping_min_delta']}") print(f"\nВыходные данные:") print(f" Директория: {config.get('output_dir', 'runs/gan')}") print(f" Интервал сохранения: {config.get('save_interval', 5)} эпох") print(f"\nЛогирование:") print(f" Интервал логирования: {config.get('log_interval', 10)} батчей") print("=" * 60) def setup_training(): """Настраивает обучение.""" print("Настройка обучения GAN...") # Выбираем конфигурацию use_advanced = False # Измените на True для расширенной конфигурации if use_advanced: config = create_advanced_config() else: config = create_simple_config() # Печатаем сводку конфигурации print_config_summary(config) # Устройство device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"\nИспользуемое устройство: {device}") if device.type == "cuda": print(f" GPU: {torch.cuda.get_device_name(0)}") print( f" Память: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB" ) # Создаем модель print("\nСоздание модели...") model = create_image_gan( input_channels=3, output_channels=3, gan_mode=config.get("gan_mode", "vanilla"), lambda_L1=config.get("lambda_L1", 100.0), use_cuda=(device.type == "cuda"), ) # Создаем даталоадеры print("\nСоздание даталоадеров...") # ЗАМЕНИТЕ ЭТО НА ВАШИ РЕАЛЬНЫЕ ДАННЫЕ # Пример: # from your_dataset_module import create_data_loaders # train_loader, val_loader = create_data_loaders( # data_dir="ваш/путь/к/данным", # batch_size=config["batch_size"], # image_size=(256, 256), # augment=config.get("augmentation", None), # ) # Для примера создаем фиктивные даталоадеры # ВАЖНО: Замените это на реальные данные! print(" ВНИМАНИЕ: Используются фиктивные данные!") print(" Замените на реальные даталоадеры!") import numpy as np from torch.utils.data import Dataset class DummyDataset(Dataset): def __init__(self, num_samples=100): self.num_samples = num_samples def __len__(self): return self.num_samples def __getitem__(self, idx): # Фиктивные данные для примера yandex_img = torch.randn(3, 256, 256) google_img = torch.randn(3, 256, 256) return {"yandex_img": yandex_img, "google_img": google_img} train_dataset = DummyDataset(num_samples=100) val_dataset = DummyDataset(num_samples=20) train_loader = DataLoader( train_dataset, batch_size=config.get("batch_size", 4), shuffle=True, num_workers=0, ) val_loader = DataLoader( val_dataset, batch_size=config.get("batch_size", 4), shuffle=False, num_workers=0, ) print(f" Размер обучающего набора: {len(train_dataset)}") print(f" Размер валидационного набора: {len(val_dataset)}") print(f" Батчей в эпохе: {len(train_loader)}") # Создаем тренер print("\nСоздание тренера...") trainer = GANTrainer( model=model, train_loader=train_loader, val_loader=val_loader, device=device, config=config, ) return trainer, config def train_model(trainer, config): """Запускает обучение модели.""" print("\n" + "=" * 60) print("Начало обучения") print("=" * 60) epochs = config.get("epochs", 100) try: trainer.train(num_epochs=epochs) print("\n" + "=" * 60) print("Обучение завершено успешно!") print("=" * 60) except KeyboardInterrupt: print("\n\nОбучение прервано пользователем.") print("Сохранение текущего состояния...") trainer.save_checkpoint(is_best=False) except Exception as e: print(f"\n\nОшибка при обучении: {e}") import traceback traceback.print_exc() # Пытаемся сохранить чекпоинт при ошибке try: trainer.save_checkpoint(is_best=False) print("Текущее состояние сохранено.") except: print("Не удалось сохранить состояние.") def evaluate_model(trainer, test_loader=None): """Оценивает обученную модель.""" print("\n" + "=" * 60) print("Оценка модели") print("=" * 60) if test_loader is None: print("Тестовый даталоадер не предоставлен.") print("Используется валидационный даталоадер для оценки.") test_loader = trainer.val_loader metrics = trainer.evaluate(test_loader) print("\nМетрики оценки:") for key, value in metrics.items(): print(f" {key}: {value:.6f}") return metrics def generate_examples(model, device, num_examples=4): """Генерирует примеры преобразования.""" print("\n" + "=" * 60) print("Генерация примеров") print("=" * 60) model.eval() # Создаем фиктивные входные данные yandex_input = torch.randn(num_examples, 3, 256, 256).to(device) with torch.no_grad(): google_output = model(yandex_input) print(f"Сгенерировано {num_examples} примеров") print(f"Размер входных данных: {yandex_input.shape}") print(f"Размер выходных данных: {google_output.shape}") # Сохраняем примеры (в реальном коде сохраняйте как изображения) print("\nПримеры сгенерированы.") print("В реальном коде сохраняйте их как изображения для визуализации.") return yandex_input, google_output def main(): """Основная функция.""" print("=" * 60) print("Пример обучения GAN для преобразования Yandex → Google") print("=" * 60) # Настройка trainer, config = setup_training() # Обучение train_model(trainer, config) # Оценка (требует реальных тестовых данных) # evaluate_model(trainer) # Генерация примеров # generate_examples(trainer.model, trainer.device) print("\n" + "=" * 60) print("Скрипт завершен.") print("=" * 60) print("\nСледующие шаги:") print("1. Замените фиктивные даталоадеры на реальные данные") print("2. Настройте параметры в create_simple_config()") print("3. Запустите обучение с реальными данными") print("4. Визуализируйте результаты") print("=" * 60) if __name__ == "__main__": main()