""" Тестовый скрипт для проверки GAN trainer. """ import sys from pathlib import Path import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader, Dataset # Добавляем путь к модулям sys.path.append(str(Path(__file__).parent.parent.parent)) from models.GAN.gan import create_image_gan from models.GAN.trainer import GANTrainer class SimpleDataset(Dataset): """Простой датасет для тестирования.""" def __init__(self, num_samples=100, image_size=(256, 256)): self.num_samples = num_samples self.image_size = image_size def __len__(self): return self.num_samples def __getitem__(self, idx): # Создаем случайные изображения для тестирования yandex_img = torch.randn(3, self.image_size[0], self.image_size[1]) google_img = torch.randn(3, self.image_size[0], self.image_size[1]) return {"yandex_img": yandex_img, "google_img": google_img} def test_gan_model(): """Тестирование GAN модели.""" print("Тестирование GAN модели...") # Создаем модель model = create_image_gan( input_channels=3, output_channels=3, gan_mode="vanilla", lambda_L1=100.0, use_cuda=False, # Используем CPU для тестирования ) # Тестируем forward pass batch_size = 2 image_size = (256, 256) yandex_input = torch.randn(batch_size, 3, *image_size) with torch.no_grad(): output = model(yandex_input) print(f"Входной размер: {yandex_input.shape}") print(f"Выходной размер: {output.shape}") print(f"Диапазон выходных значений: [{output.min():.3f}, {output.max():.3f}]") # Проверяем, что выход в диапазоне [-1, 1] (из-за Tanh) assert output.min() >= -1.0 and output.max() <= 1.0, "Выход не в диапазоне [-1, 1]" print("✓ Forward pass работает корректно") return model def test_generator_step(): """Тестирование шага генератора.""" print("\nТестирование шага генератора...") model = create_image_gan(use_cuda=False) model.train_mode() batch_size = 2 yandex_img = torch.randn(batch_size, 3, 256, 256) google_img = torch.randn(batch_size, 3, 256, 256) # Тестируем generator_step total_loss, gan_loss, l1_loss = model.generator_step(yandex_img, google_img) print(f"Total loss: {total_loss.item():.6f}") print(f"GAN loss: {gan_loss.item():.6f}") print(f"L1 loss: {l1_loss.item():.6f}") assert total_loss.item() > 0, "Потери должны быть положительными" print("✓ Шаг генератора работает корректно") def test_discriminator_step(): """Тестирование шага дискриминатора.""" print("\nТестирование шага дискриминатора...") model = create_image_gan(use_cuda=False) model.train_mode() batch_size = 2 yandex_img = torch.randn(batch_size, 3, 256, 256) google_img = torch.randn(batch_size, 3, 256, 256) # Генерируем fake изображение with torch.no_grad(): fake_google_img = model.generator(yandex_img) # Тестируем discriminator_step total_loss, real_loss, fake_loss = model.discriminator_step( yandex_img, google_img, fake_google_img ) print(f"Total loss: {total_loss.item():.6f}") print(f"Real loss: {real_loss.item():.6f}") print(f"Fake loss: {fake_loss.item():.6f}") assert total_loss.item() > 0, "Потери должны быть положительными" print("✓ Шаг дискриминатора работает корректно") def test_trainer_initialization(): """Тестирование инициализации тренера.""" print("\nТестирование инициализации тренера...") # Создаем модель model = create_image_gan(use_cuda=False) # Создаем даталоадеры train_dataset = SimpleDataset(num_samples=50) val_dataset = SimpleDataset(num_samples=10) train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False) # Конфигурация config = { "learning_rate": 2e-4, "beta1": 0.5, "beta2": 0.999, "output_dir": "test_runs/gan", "early_stopping_patience": 10, } # Создаем тренер device = torch.device("cpu") trainer = GANTrainer( model=model, train_loader=train_loader, val_loader=val_loader, device=device, config=config, ) print(f"Тренер создан успешно") print(f"Оптимизатор генератора: {type(trainer.optimizer_G).__name__}") print(f"Оптимизатор дискриминатора: {type(trainer.optimizer_D).__name__}") print(f"Выходная директория: {trainer.output_dir}") assert trainer.output_dir.exists(), "Выходная директория не создана" print("✓ Тренер инициализирован корректно") return trainer, train_loader, val_loader def test_train_epoch(): """Тестирование одной эпохи обучения.""" print("\nТестирование одной эпохи обучения...") model = create_image_gan(use_cuda=False) train_dataset = SimpleDataset(num_samples=20) train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True) val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False) config = { "learning_rate": 2e-4, "beta1": 0.5, "beta2": 0.999, "output_dir": "test_runs/gan", } device = torch.device("cpu") trainer = GANTrainer( model=model, train_loader=train_loader, val_loader=val_loader, device=device, config=config, ) # Запускаем одну эпоху обучения avg_g_loss, avg_d_loss = trainer.train_epoch() print(f"Средние потери за эпоху:") print(f" Генератор: {avg_g_loss:.6f}") print(f" Дискриминатор: {avg_d_loss:.6f}") assert avg_g_loss > 0, "Потери генератора должны быть положительными" assert avg_d_loss > 0, "Потери дискриминатора должны быть положительными" print("✓ Эпоха обучения завершена успешно") def test_validation(): """Тестирование валидации.""" print("\nТестирование валидации...") model = create_image_gan(use_cuda=False) train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True) val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False) config = { "learning_rate": 2e-4, "beta1": 0.5, "beta2": 0.999, "output_dir": "test_runs/gan", } device = torch.device("cpu") trainer = GANTrainer( model=model, train_loader=train_loader, val_loader=val_loader, device=device, config=config, ) # Запускаем валидацию val_g_loss, val_d_loss = trainer.validate() print(f"Потери на валидации:") print(f" Генератор: {val_g_loss:.6f}") print(f" Дискриминатор: {val_d_loss:.6f}") assert val_g_loss > 0, "Потери генератора должны быть положительными" assert val_d_loss > 0, "Потери дискриминатора должны быть положительными" print("✓ Валидация завершена успешно") def test_checkpoint_saving(): """Тестирование сохранения чекпоинтов.""" print("\nТестирование сохранения чекпоинтов...") model = create_image_gan(use_cuda=False) train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True) val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False) config = { "learning_rate": 2e-4, "beta1": 0.5, "beta2": 0.999, "output_dir": "test_runs/gan_checkpoint", } device = torch.device("cpu") trainer = GANTrainer( model=model, train_loader=train_loader, val_loader=val_loader, device=device, config=config, ) # Сохраняем чекпоинт trainer.save_checkpoint(is_best=True) # Проверяем, что файлы созданы checkpoint_files = list(trainer.output_dir.glob("*.pth")) print(f"Создано файлов чекпоинтов: {len(checkpoint_files)}") for file in checkpoint_files: print(f" - {file.name}") assert len(checkpoint_files) > 0, "Файлы чекпоинтов не созданы" print("✓ Чекпоинты сохранены успешно") # Тестируем загрузку чекпоинта checkpoint_path = checkpoint_files[0] print(f"\nТестируем загрузку чекпоинта: {checkpoint_path}") # Создаем новую модель и тренер new_model = create_image_gan(use_cuda=False) new_trainer = GANTrainer( model=new_model, train_loader=train_loader, val_loader=val_loader, device=device, config=config, ) # Загружаем чекпоинт new_trainer.load_checkpoint(str(checkpoint_path)) print(f"Загружен чекпоинт эпохи: {new_trainer.current_epoch + 1}") print("✓ Чекпоинт загружен успешно") def main(): """Основная функция тестирования.""" print("=" * 60) print("Начало тестирования GAN trainer") print("=" * 60) try: # Запускаем все тесты test_gan_model() test_generator_step() test_discriminator_step() test_trainer_initialization() test_train_epoch() test_validation() test_checkpoint_saving() print("\n" + "=" * 60) print("Все тесты пройдены успешно! ✓") print("=" * 60) except Exception as e: print(f"\nОшибка при тестировании: {e}") import traceback traceback.print_exc() return 1 return 0 if __name__ == "__main__": # Очищаем тестовые директории import shutil test_dirs = ["test_runs/gan", "test_runs/gan_checkpoint"] for dir_path in test_dirs: if Path(dir_path).exists(): shutil.rmtree(dir_path) # Запускаем тесты exit_code = main() # Очищаем после тестов for dir_path in test_dirs: if Path(dir_path).exists(): shutil.rmtree(dir_path) exit(exit_code)