Files
autopilot/models/GAN/test_gan.py
2026-02-20 16:52:02 +03:00

350 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import sys
import torch
import torch.nn as nn
# Добавляем путь к модулю
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
from gan import (
DiscriminatorPatchGAN,
GeneratorUNet,
ImageGAN,
create_image_gan,
initialize_gan_weights,
)
def test_generator():
"""Тестирование генератора"""
print("=" * 60)
print("Тестирование генератора...")
print("=" * 60)
# Создаем генератор
generator = GeneratorUNet(in_channels=3, out_channels=3)
# Инициализируем веса
generator.apply(
lambda m: (
nn.init.normal_(m.weight.data, 0.0, 0.02)
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d)
else None
)
)
# Создаем тестовый входной тензор (Yandex изображение)
batch_size = 2
height, width = 700, 700
yandex_image = torch.randn(batch_size, 3, height, width)
print(f"Размер входного изображения: {yandex_image.shape}")
print(
f"Количество параметров генератора: {sum(p.numel() for p in generator.parameters()):,}"
)
# Прямой проход
with torch.no_grad():
generated_image = generator(yandex_image)
print(f"Размер сгенерированного изображения: {generated_image.shape}")
print(
f"Диапазон значений сгенерированного изображения: [{generated_image.min():.3f}, {generated_image.max():.3f}]"
)
# Проверка размеров
assert generated_image.shape == (batch_size, 3, height, width), (
f"Ожидался размер {(batch_size, 3, height, width)}, получен {generated_image.shape}"
)
print("✓ Генератор работает корректно!")
return generator
def test_discriminator():
"""Тестирование дискриминатора"""
print("\n" + "=" * 60)
print("Тестирование дискриминатора...")
print("=" * 60)
# Создаем дискриминатор
discriminator = DiscriminatorPatchGAN(in_channels=6) # 3 + 3 канала
# Инициализируем веса
discriminator.apply(
lambda m: (
nn.init.normal_(m.weight.data, 0.0, 0.02)
if isinstance(m, nn.Conv2d)
else None
)
)
# Создаем тестовые тензоры
batch_size = 2
height, width = 700, 700
yandex_image = torch.randn(batch_size, 3, height, width)
google_image = torch.randn(batch_size, 3, height, width)
print(f"Размер Yandex изображения: {yandex_image.shape}")
print(f"Размер Google изображения: {google_image.shape}")
print(
f"Количество параметров дискриминатора: {sum(p.numel() for p in discriminator.parameters()):,}"
)
# Прямой проход
with torch.no_grad():
prediction = discriminator(yandex_image, google_image)
print(f"Размер выхода дискриминатора: {prediction.shape}")
print(
f"Диапазон значений предсказания: [{prediction.min():.3f}, {prediction.max():.3f}]"
)
# Проверка размеров (PatchGAN выдает карту вероятностей)
expected_height = 43 # Для изображения 700x700 после 4 downsampling блоков
expected_width = 43
assert prediction.shape == (batch_size, 1, expected_height, expected_width), (
f"Ожидался размер {(batch_size, 1, expected_height, expected_width)}, получен {prediction.shape}"
)
print(
f"✓ Дискриминатор работает корректно! Выходной размер: {prediction.shape[2]}x{prediction.shape[3]}"
)
return discriminator
def test_gan_model():
"""Тестирование полной GAN модели"""
print("\n" + "=" * 60)
print("Тестирование полной GAN модели...")
print("=" * 60)
# Создаем GAN модель
gan = ImageGAN(
input_channels=3,
output_channels=3,
gan_mode="vanilla",
lambda_L1=100.0,
use_cuda=False, # Тестируем на CPU для простоты
)
print(f"Устройство модели: {gan.device}")
print(
f"Количество параметров генератора: {sum(p.numel() for p in gan.generator.parameters()):,}"
)
print(
f"Количество параметров дискриминатора: {sum(p.numel() for p in gan.discriminator.parameters()):,}"
)
print(f"Общее количество параметров: {sum(p.numel() for p in gan.parameters()):,}")
# Создаем тестовые данные
batch_size = 2
height, width = 700, 700
yandex_image = torch.randn(batch_size, 3, height, width)
real_google_image = torch.randn(batch_size, 3, height, width)
print(f"\nТестирование прямого прохода...")
with torch.no_grad():
generated_image = gan(yandex_image)
print(f"Размер сгенерированного изображения: {generated_image.shape}")
print(f"\nТестирование шага генератора...")
gan.train_mode()
# Тестируем шаг генератора
total_loss, gan_loss, l1_loss = gan.generator_step(yandex_image, real_google_image)
print(f"Общие потери генератора: {total_loss.item():.6f}")
print(f"Потери GAN: {gan_loss.item():.6f}")
print(f"Потери L1: {l1_loss.item():.6f}")
print(f"\nТестирование шага дискриминатора...")
# Создаем сгенерированное изображение для дискриминатора
with torch.no_grad():
fake_google_image = gan.generator(yandex_image)
total_d_loss, real_loss, fake_loss = gan.discriminator_step(
yandex_image, real_google_image, fake_google_image
)
print(f"Общие потери дискриминатора: {total_d_loss.item():.6f}")
print(f"Потери на реальных изображениях: {real_loss.item():.6f}")
print(f"Потери на сгенерированных изображениях: {fake_loss.item():.6f}")
print(f"\nТестирование режимов обучения/оценки...")
gan.eval_mode()
print(f"Генератор в режиме eval: {not gan.generator.training}")
print(f"Дискриминатор в режиме eval: {not gan.discriminator.training}")
gan.train_mode()
print(f"Генератор в режиме train: {gan.generator.training}")
print(f"Дискриминатор в режиме train: {gan.discriminator.training}")
print("\n✓ Полная GAN модель работает корректно!")
return gan
def test_factory_function():
"""Тестирование фабричной функции"""
print("\n" + "=" * 60)
print("Тестирование фабричной функции...")
print("=" * 60)
# Тестируем разные режимы GAN
for gan_mode in ["vanilla", "lsgan"]:
print(f"\nСоздание GAN в режиме '{gan_mode}'...")
gan = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode=gan_mode,
lambda_L1=100.0,
use_cuda=False,
)
print(f" Режим GAN: {gan.gan_loss.gan_mode}")
print(f" Вес L1 потерь: {gan.lambda_L1}")
print(f" Устройство: {gan.device}")
# Быстрая проверка прямого прохода
batch_size = 1
yandex_image = torch.randn(batch_size, 3, 700, 700)
with torch.no_grad():
generated = gan(yandex_image)
print(f" Размер выхода: {generated.shape}")
print(f" ✓ GAN в режиме '{gan_mode}' создан успешно")
print("\n✓ Фабричная функция работает корректно!")
def test_weights_initialization():
"""Тестирование инициализации весов"""
print("\n" + "=" * 60)
print("Тестирование инициализации весов...")
print("=" * 60)
# Создаем модели
generator = GeneratorUNet(3, 3)
discriminator = DiscriminatorPatchGAN(6)
# Инициализируем веса
initialize_gan_weights(generator, discriminator)
# Проверяем средние значения весов
def check_weights_mean(model, model_name):
conv_weights = []
for name, param in model.named_parameters():
if "weight" in name and (
"conv" in name.lower() or "Conv" in str(param.__class__)
):
conv_weights.append(param.data.mean().item())
if conv_weights:
avg_mean = sum(conv_weights) / len(conv_weights)
print(f" Среднее значение весов Conv слоев в {model_name}: {avg_mean:.6f}")
# Проверяем, что веса инициализированы около 0
assert abs(avg_mean) < 0.1, f"Веса {model_name} не инициализированы около 0"
check_weights_mean(generator, "генераторе")
check_weights_mean(discriminator, "дискриминаторе")
print("✓ Инициализация весов работает корректно!")
def test_memory_usage():
"""Тестирование использования памяти"""
print("\n" + "=" * 60)
print("Тестирование использования памяти...")
print("=" * 60)
import os
import psutil
# Получаем текущее использование памяти
process = psutil.Process(os.getpid())
memory_before = process.memory_info().rss / 1024 / 1024 # в MB
print(f"Память до создания моделей: {memory_before:.2f} MB")
# Создаем несколько моделей
models = []
for i in range(3):
gan = create_image_gan(use_cuda=False)
models.append(gan)
# Делаем тестовый проход
batch_size = 1
yandex_image = torch.randn(batch_size, 3, 700, 700)
real_google_image = torch.randn(batch_size, 3, 700, 700)
with torch.no_grad():
_ = gan(yandex_image)
_ = gan.generator_step(yandex_image, real_google_image)
memory_after = process.memory_info().rss / 1024 / 1024 # в MB
memory_used = memory_after - memory_before
print(f"Память после создания моделей: {memory_after:.2f} MB")
print(f"Использовано памяти: {memory_used:.2f} MB")
# Очищаем модели
del models
import gc
gc.collect()
memory_final = process.memory_info().rss / 1024 / 1024
print(f"Память после очистки: {memory_final:.2f} MB")
print("✓ Тестирование памяти завершено!")
def main():
"""Основная функция тестирования"""
print("Начало тестирования GAN архитектуры для преобразования Yandex → Google")
print("Размер изображения: 700x700 пикселей")
print("=" * 60)
try:
# Запускаем все тесты
test_generator()
test_discriminator()
test_gan_model()
test_factory_function()
test_weights_initialization()
test_memory_usage()
print("\n" + "=" * 60)
print("ВСЕ ТЕСТЫ ПРОЙДЕНЫ УСПЕШНО! 🎉")
print("=" * 60)
print("\nАрхитектура GAN готова к использованию для преобразования")
print("изображений из стиля Yandex в стиль Google.")
print("\nОсновные характеристики:")
print(" • Генератор: U-Net архитектура")
print(" • Дискриминатор: PatchGAN (43x43 патчей)")
print(" • Размер входных/выходных изображений: 700x700")
print(" • Поддержка режимов: vanilla, lsgan")
print(" • L1 регуляризация для сохранения структуры")
except Exception as e:
print(f"\n❌ Ошибка при тестировании: {e}")
import traceback
traceback.print_exc()
return 1
return 0
if __name__ == "__main__":
exit_code = main()
sys.exit(exit_code)