import os import sys import torch import torch.nn as nn # Добавляем путь к модулю sys.path.append( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) from gan import ( DiscriminatorPatchGAN, GeneratorUNet, ImageGAN, create_image_gan, initialize_gan_weights, ) def test_generator(): """Тестирование генератора""" print("=" * 60) print("Тестирование генератора...") print("=" * 60) # Создаем генератор generator = GeneratorUNet(in_channels=3, out_channels=3) # Инициализируем веса generator.apply( lambda m: ( nn.init.normal_(m.weight.data, 0.0, 0.02) if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) else None ) ) # Создаем тестовый входной тензор (Yandex изображение) batch_size = 2 height, width = 700, 700 yandex_image = torch.randn(batch_size, 3, height, width) print(f"Размер входного изображения: {yandex_image.shape}") print( f"Количество параметров генератора: {sum(p.numel() for p in generator.parameters()):,}" ) # Прямой проход with torch.no_grad(): generated_image = generator(yandex_image) print(f"Размер сгенерированного изображения: {generated_image.shape}") print( f"Диапазон значений сгенерированного изображения: [{generated_image.min():.3f}, {generated_image.max():.3f}]" ) # Проверка размеров assert generated_image.shape == (batch_size, 3, height, width), ( f"Ожидался размер {(batch_size, 3, height, width)}, получен {generated_image.shape}" ) print("✓ Генератор работает корректно!") return generator def test_discriminator(): """Тестирование дискриминатора""" print("\n" + "=" * 60) print("Тестирование дискриминатора...") print("=" * 60) # Создаем дискриминатор discriminator = DiscriminatorPatchGAN(in_channels=6) # 3 + 3 канала # Инициализируем веса discriminator.apply( lambda m: ( nn.init.normal_(m.weight.data, 0.0, 0.02) if isinstance(m, nn.Conv2d) else None ) ) # Создаем тестовые тензоры batch_size = 2 height, width = 700, 700 yandex_image = torch.randn(batch_size, 3, height, width) google_image = torch.randn(batch_size, 3, height, width) print(f"Размер Yandex изображения: {yandex_image.shape}") print(f"Размер Google изображения: {google_image.shape}") print( f"Количество параметров дискриминатора: {sum(p.numel() for p in discriminator.parameters()):,}" ) # Прямой проход with torch.no_grad(): prediction = discriminator(yandex_image, google_image) print(f"Размер выхода дискриминатора: {prediction.shape}") print( f"Диапазон значений предсказания: [{prediction.min():.3f}, {prediction.max():.3f}]" ) # Проверка размеров (PatchGAN выдает карту вероятностей) expected_height = 43 # Для изображения 700x700 после 4 downsampling блоков expected_width = 43 assert prediction.shape == (batch_size, 1, expected_height, expected_width), ( f"Ожидался размер {(batch_size, 1, expected_height, expected_width)}, получен {prediction.shape}" ) print( f"✓ Дискриминатор работает корректно! Выходной размер: {prediction.shape[2]}x{prediction.shape[3]}" ) return discriminator def test_gan_model(): """Тестирование полной GAN модели""" print("\n" + "=" * 60) print("Тестирование полной GAN модели...") print("=" * 60) # Создаем GAN модель gan = ImageGAN( input_channels=3, output_channels=3, gan_mode="vanilla", lambda_L1=100.0, use_cuda=False, # Тестируем на CPU для простоты ) print(f"Устройство модели: {gan.device}") print( f"Количество параметров генератора: {sum(p.numel() for p in gan.generator.parameters()):,}" ) print( f"Количество параметров дискриминатора: {sum(p.numel() for p in gan.discriminator.parameters()):,}" ) print(f"Общее количество параметров: {sum(p.numel() for p in gan.parameters()):,}") # Создаем тестовые данные batch_size = 2 height, width = 700, 700 yandex_image = torch.randn(batch_size, 3, height, width) real_google_image = torch.randn(batch_size, 3, height, width) print(f"\nТестирование прямого прохода...") with torch.no_grad(): generated_image = gan(yandex_image) print(f"Размер сгенерированного изображения: {generated_image.shape}") print(f"\nТестирование шага генератора...") gan.train_mode() # Тестируем шаг генератора total_loss, gan_loss, l1_loss = gan.generator_step(yandex_image, real_google_image) print(f"Общие потери генератора: {total_loss.item():.6f}") print(f"Потери GAN: {gan_loss.item():.6f}") print(f"Потери L1: {l1_loss.item():.6f}") print(f"\nТестирование шага дискриминатора...") # Создаем сгенерированное изображение для дискриминатора with torch.no_grad(): fake_google_image = gan.generator(yandex_image) total_d_loss, real_loss, fake_loss = gan.discriminator_step( yandex_image, real_google_image, fake_google_image ) print(f"Общие потери дискриминатора: {total_d_loss.item():.6f}") print(f"Потери на реальных изображениях: {real_loss.item():.6f}") print(f"Потери на сгенерированных изображениях: {fake_loss.item():.6f}") print(f"\nТестирование режимов обучения/оценки...") gan.eval_mode() print(f"Генератор в режиме eval: {not gan.generator.training}") print(f"Дискриминатор в режиме eval: {not gan.discriminator.training}") gan.train_mode() print(f"Генератор в режиме train: {gan.generator.training}") print(f"Дискриминатор в режиме train: {gan.discriminator.training}") print("\n✓ Полная GAN модель работает корректно!") return gan def test_factory_function(): """Тестирование фабричной функции""" print("\n" + "=" * 60) print("Тестирование фабричной функции...") print("=" * 60) # Тестируем разные режимы GAN for gan_mode in ["vanilla", "lsgan"]: print(f"\nСоздание GAN в режиме '{gan_mode}'...") gan = create_image_gan( input_channels=3, output_channels=3, gan_mode=gan_mode, lambda_L1=100.0, use_cuda=False, ) print(f" Режим GAN: {gan.gan_loss.gan_mode}") print(f" Вес L1 потерь: {gan.lambda_L1}") print(f" Устройство: {gan.device}") # Быстрая проверка прямого прохода batch_size = 1 yandex_image = torch.randn(batch_size, 3, 700, 700) with torch.no_grad(): generated = gan(yandex_image) print(f" Размер выхода: {generated.shape}") print(f" ✓ GAN в режиме '{gan_mode}' создан успешно") print("\n✓ Фабричная функция работает корректно!") def test_weights_initialization(): """Тестирование инициализации весов""" print("\n" + "=" * 60) print("Тестирование инициализации весов...") print("=" * 60) # Создаем модели generator = GeneratorUNet(3, 3) discriminator = DiscriminatorPatchGAN(6) # Инициализируем веса initialize_gan_weights(generator, discriminator) # Проверяем средние значения весов def check_weights_mean(model, model_name): conv_weights = [] for name, param in model.named_parameters(): if "weight" in name and ( "conv" in name.lower() or "Conv" in str(param.__class__) ): conv_weights.append(param.data.mean().item()) if conv_weights: avg_mean = sum(conv_weights) / len(conv_weights) print(f" Среднее значение весов Conv слоев в {model_name}: {avg_mean:.6f}") # Проверяем, что веса инициализированы около 0 assert abs(avg_mean) < 0.1, f"Веса {model_name} не инициализированы около 0" check_weights_mean(generator, "генераторе") check_weights_mean(discriminator, "дискриминаторе") print("✓ Инициализация весов работает корректно!") def test_memory_usage(): """Тестирование использования памяти""" print("\n" + "=" * 60) print("Тестирование использования памяти...") print("=" * 60) import os import psutil # Получаем текущее использование памяти process = psutil.Process(os.getpid()) memory_before = process.memory_info().rss / 1024 / 1024 # в MB print(f"Память до создания моделей: {memory_before:.2f} MB") # Создаем несколько моделей models = [] for i in range(3): gan = create_image_gan(use_cuda=False) models.append(gan) # Делаем тестовый проход batch_size = 1 yandex_image = torch.randn(batch_size, 3, 700, 700) real_google_image = torch.randn(batch_size, 3, 700, 700) with torch.no_grad(): _ = gan(yandex_image) _ = gan.generator_step(yandex_image, real_google_image) memory_after = process.memory_info().rss / 1024 / 1024 # в MB memory_used = memory_after - memory_before print(f"Память после создания моделей: {memory_after:.2f} MB") print(f"Использовано памяти: {memory_used:.2f} MB") # Очищаем модели del models import gc gc.collect() memory_final = process.memory_info().rss / 1024 / 1024 print(f"Память после очистки: {memory_final:.2f} MB") print("✓ Тестирование памяти завершено!") def main(): """Основная функция тестирования""" print("Начало тестирования GAN архитектуры для преобразования Yandex → Google") print("Размер изображения: 700x700 пикселей") print("=" * 60) try: # Запускаем все тесты test_generator() test_discriminator() test_gan_model() test_factory_function() test_weights_initialization() test_memory_usage() print("\n" + "=" * 60) print("ВСЕ ТЕСТЫ ПРОЙДЕНЫ УСПЕШНО! 🎉") print("=" * 60) print("\nАрхитектура GAN готова к использованию для преобразования") print("изображений из стиля Yandex в стиль Google.") print("\nОсновные характеристики:") print(" • Генератор: U-Net архитектура") print(" • Дискриминатор: PatchGAN (43x43 патчей)") print(" • Размер входных/выходных изображений: 700x700") print(" • Поддержка режимов: vanilla, lsgan") print(" • L1 регуляризация для сохранения структуры") except Exception as e: print(f"\n❌ Ошибка при тестировании: {e}") import traceback traceback.print_exc() return 1 return 0 if __name__ == "__main__": exit_code = main() sys.exit(exit_code)