350 lines
13 KiB
Python
350 lines
13 KiB
Python
import os
|
||
import sys
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
# Добавляем путь к модулю
|
||
sys.path.append(
|
||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
)
|
||
|
||
from gan import (
|
||
DiscriminatorPatchGAN,
|
||
GeneratorUNet,
|
||
ImageGAN,
|
||
create_image_gan,
|
||
initialize_gan_weights,
|
||
)
|
||
|
||
|
||
def test_generator():
|
||
"""Тестирование генератора"""
|
||
print("=" * 60)
|
||
print("Тестирование генератора...")
|
||
print("=" * 60)
|
||
|
||
# Создаем генератор
|
||
generator = GeneratorUNet(in_channels=3, out_channels=3)
|
||
|
||
# Инициализируем веса
|
||
generator.apply(
|
||
lambda m: (
|
||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d)
|
||
else None
|
||
)
|
||
)
|
||
|
||
# Создаем тестовый входной тензор (Yandex изображение)
|
||
batch_size = 2
|
||
height, width = 700, 700
|
||
yandex_image = torch.randn(batch_size, 3, height, width)
|
||
|
||
print(f"Размер входного изображения: {yandex_image.shape}")
|
||
print(
|
||
f"Количество параметров генератора: {sum(p.numel() for p in generator.parameters()):,}"
|
||
)
|
||
|
||
# Прямой проход
|
||
with torch.no_grad():
|
||
generated_image = generator(yandex_image)
|
||
|
||
print(f"Размер сгенерированного изображения: {generated_image.shape}")
|
||
print(
|
||
f"Диапазон значений сгенерированного изображения: [{generated_image.min():.3f}, {generated_image.max():.3f}]"
|
||
)
|
||
|
||
# Проверка размеров
|
||
assert generated_image.shape == (batch_size, 3, height, width), (
|
||
f"Ожидался размер {(batch_size, 3, height, width)}, получен {generated_image.shape}"
|
||
)
|
||
|
||
print("✓ Генератор работает корректно!")
|
||
return generator
|
||
|
||
|
||
def test_discriminator():
|
||
"""Тестирование дискриминатора"""
|
||
print("\n" + "=" * 60)
|
||
print("Тестирование дискриминатора...")
|
||
print("=" * 60)
|
||
|
||
# Создаем дискриминатор
|
||
discriminator = DiscriminatorPatchGAN(in_channels=6) # 3 + 3 канала
|
||
|
||
# Инициализируем веса
|
||
discriminator.apply(
|
||
lambda m: (
|
||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||
if isinstance(m, nn.Conv2d)
|
||
else None
|
||
)
|
||
)
|
||
|
||
# Создаем тестовые тензоры
|
||
batch_size = 2
|
||
height, width = 700, 700
|
||
|
||
yandex_image = torch.randn(batch_size, 3, height, width)
|
||
google_image = torch.randn(batch_size, 3, height, width)
|
||
|
||
print(f"Размер Yandex изображения: {yandex_image.shape}")
|
||
print(f"Размер Google изображения: {google_image.shape}")
|
||
print(
|
||
f"Количество параметров дискриминатора: {sum(p.numel() for p in discriminator.parameters()):,}"
|
||
)
|
||
|
||
# Прямой проход
|
||
with torch.no_grad():
|
||
prediction = discriminator(yandex_image, google_image)
|
||
|
||
print(f"Размер выхода дискриминатора: {prediction.shape}")
|
||
print(
|
||
f"Диапазон значений предсказания: [{prediction.min():.3f}, {prediction.max():.3f}]"
|
||
)
|
||
|
||
# Проверка размеров (PatchGAN выдает карту вероятностей)
|
||
expected_height = 43 # Для изображения 700x700 после 4 downsampling блоков
|
||
expected_width = 43
|
||
assert prediction.shape == (batch_size, 1, expected_height, expected_width), (
|
||
f"Ожидался размер {(batch_size, 1, expected_height, expected_width)}, получен {prediction.shape}"
|
||
)
|
||
|
||
print(
|
||
f"✓ Дискриминатор работает корректно! Выходной размер: {prediction.shape[2]}x{prediction.shape[3]}"
|
||
)
|
||
return discriminator
|
||
|
||
|
||
def test_gan_model():
|
||
"""Тестирование полной GAN модели"""
|
||
print("\n" + "=" * 60)
|
||
print("Тестирование полной GAN модели...")
|
||
print("=" * 60)
|
||
|
||
# Создаем GAN модель
|
||
gan = ImageGAN(
|
||
input_channels=3,
|
||
output_channels=3,
|
||
gan_mode="vanilla",
|
||
lambda_L1=100.0,
|
||
use_cuda=False, # Тестируем на CPU для простоты
|
||
)
|
||
|
||
print(f"Устройство модели: {gan.device}")
|
||
print(
|
||
f"Количество параметров генератора: {sum(p.numel() for p in gan.generator.parameters()):,}"
|
||
)
|
||
print(
|
||
f"Количество параметров дискриминатора: {sum(p.numel() for p in gan.discriminator.parameters()):,}"
|
||
)
|
||
print(f"Общее количество параметров: {sum(p.numel() for p in gan.parameters()):,}")
|
||
|
||
# Создаем тестовые данные
|
||
batch_size = 2
|
||
height, width = 700, 700
|
||
|
||
yandex_image = torch.randn(batch_size, 3, height, width)
|
||
real_google_image = torch.randn(batch_size, 3, height, width)
|
||
|
||
print(f"\nТестирование прямого прохода...")
|
||
with torch.no_grad():
|
||
generated_image = gan(yandex_image)
|
||
|
||
print(f"Размер сгенерированного изображения: {generated_image.shape}")
|
||
|
||
print(f"\nТестирование шага генератора...")
|
||
gan.train_mode()
|
||
|
||
# Тестируем шаг генератора
|
||
total_loss, gan_loss, l1_loss = gan.generator_step(yandex_image, real_google_image)
|
||
|
||
print(f"Общие потери генератора: {total_loss.item():.6f}")
|
||
print(f"Потери GAN: {gan_loss.item():.6f}")
|
||
print(f"Потери L1: {l1_loss.item():.6f}")
|
||
|
||
print(f"\nТестирование шага дискриминатора...")
|
||
# Создаем сгенерированное изображение для дискриминатора
|
||
with torch.no_grad():
|
||
fake_google_image = gan.generator(yandex_image)
|
||
|
||
total_d_loss, real_loss, fake_loss = gan.discriminator_step(
|
||
yandex_image, real_google_image, fake_google_image
|
||
)
|
||
|
||
print(f"Общие потери дискриминатора: {total_d_loss.item():.6f}")
|
||
print(f"Потери на реальных изображениях: {real_loss.item():.6f}")
|
||
print(f"Потери на сгенерированных изображениях: {fake_loss.item():.6f}")
|
||
|
||
print(f"\nТестирование режимов обучения/оценки...")
|
||
gan.eval_mode()
|
||
print(f"Генератор в режиме eval: {not gan.generator.training}")
|
||
print(f"Дискриминатор в режиме eval: {not gan.discriminator.training}")
|
||
|
||
gan.train_mode()
|
||
print(f"Генератор в режиме train: {gan.generator.training}")
|
||
print(f"Дискриминатор в режиме train: {gan.discriminator.training}")
|
||
|
||
print("\n✓ Полная GAN модель работает корректно!")
|
||
return gan
|
||
|
||
|
||
def test_factory_function():
|
||
"""Тестирование фабричной функции"""
|
||
print("\n" + "=" * 60)
|
||
print("Тестирование фабричной функции...")
|
||
print("=" * 60)
|
||
|
||
# Тестируем разные режимы GAN
|
||
for gan_mode in ["vanilla", "lsgan"]:
|
||
print(f"\nСоздание GAN в режиме '{gan_mode}'...")
|
||
gan = create_image_gan(
|
||
input_channels=3,
|
||
output_channels=3,
|
||
gan_mode=gan_mode,
|
||
lambda_L1=100.0,
|
||
use_cuda=False,
|
||
)
|
||
|
||
print(f" Режим GAN: {gan.gan_loss.gan_mode}")
|
||
print(f" Вес L1 потерь: {gan.lambda_L1}")
|
||
print(f" Устройство: {gan.device}")
|
||
|
||
# Быстрая проверка прямого прохода
|
||
batch_size = 1
|
||
yandex_image = torch.randn(batch_size, 3, 700, 700)
|
||
|
||
with torch.no_grad():
|
||
generated = gan(yandex_image)
|
||
|
||
print(f" Размер выхода: {generated.shape}")
|
||
print(f" ✓ GAN в режиме '{gan_mode}' создан успешно")
|
||
|
||
print("\n✓ Фабричная функция работает корректно!")
|
||
|
||
|
||
def test_weights_initialization():
|
||
"""Тестирование инициализации весов"""
|
||
print("\n" + "=" * 60)
|
||
print("Тестирование инициализации весов...")
|
||
print("=" * 60)
|
||
|
||
# Создаем модели
|
||
generator = GeneratorUNet(3, 3)
|
||
discriminator = DiscriminatorPatchGAN(6)
|
||
|
||
# Инициализируем веса
|
||
initialize_gan_weights(generator, discriminator)
|
||
|
||
# Проверяем средние значения весов
|
||
def check_weights_mean(model, model_name):
|
||
conv_weights = []
|
||
for name, param in model.named_parameters():
|
||
if "weight" in name and (
|
||
"conv" in name.lower() or "Conv" in str(param.__class__)
|
||
):
|
||
conv_weights.append(param.data.mean().item())
|
||
|
||
if conv_weights:
|
||
avg_mean = sum(conv_weights) / len(conv_weights)
|
||
print(f" Среднее значение весов Conv слоев в {model_name}: {avg_mean:.6f}")
|
||
# Проверяем, что веса инициализированы около 0
|
||
assert abs(avg_mean) < 0.1, f"Веса {model_name} не инициализированы около 0"
|
||
|
||
check_weights_mean(generator, "генераторе")
|
||
check_weights_mean(discriminator, "дискриминаторе")
|
||
|
||
print("✓ Инициализация весов работает корректно!")
|
||
|
||
|
||
def test_memory_usage():
|
||
"""Тестирование использования памяти"""
|
||
print("\n" + "=" * 60)
|
||
print("Тестирование использования памяти...")
|
||
print("=" * 60)
|
||
|
||
import os
|
||
|
||
import psutil
|
||
|
||
# Получаем текущее использование памяти
|
||
process = psutil.Process(os.getpid())
|
||
memory_before = process.memory_info().rss / 1024 / 1024 # в MB
|
||
|
||
print(f"Память до создания моделей: {memory_before:.2f} MB")
|
||
|
||
# Создаем несколько моделей
|
||
models = []
|
||
for i in range(3):
|
||
gan = create_image_gan(use_cuda=False)
|
||
models.append(gan)
|
||
|
||
# Делаем тестовый проход
|
||
batch_size = 1
|
||
yandex_image = torch.randn(batch_size, 3, 700, 700)
|
||
real_google_image = torch.randn(batch_size, 3, 700, 700)
|
||
|
||
with torch.no_grad():
|
||
_ = gan(yandex_image)
|
||
_ = gan.generator_step(yandex_image, real_google_image)
|
||
|
||
memory_after = process.memory_info().rss / 1024 / 1024 # в MB
|
||
memory_used = memory_after - memory_before
|
||
|
||
print(f"Память после создания моделей: {memory_after:.2f} MB")
|
||
print(f"Использовано памяти: {memory_used:.2f} MB")
|
||
|
||
# Очищаем модели
|
||
del models
|
||
import gc
|
||
|
||
gc.collect()
|
||
|
||
memory_final = process.memory_info().rss / 1024 / 1024
|
||
print(f"Память после очистки: {memory_final:.2f} MB")
|
||
|
||
print("✓ Тестирование памяти завершено!")
|
||
|
||
|
||
def main():
|
||
"""Основная функция тестирования"""
|
||
print("Начало тестирования GAN архитектуры для преобразования Yandex → Google")
|
||
print("Размер изображения: 700x700 пикселей")
|
||
print("=" * 60)
|
||
|
||
try:
|
||
# Запускаем все тесты
|
||
test_generator()
|
||
test_discriminator()
|
||
test_gan_model()
|
||
test_factory_function()
|
||
test_weights_initialization()
|
||
test_memory_usage()
|
||
|
||
print("\n" + "=" * 60)
|
||
print("ВСЕ ТЕСТЫ ПРОЙДЕНЫ УСПЕШНО! 🎉")
|
||
print("=" * 60)
|
||
print("\nАрхитектура GAN готова к использованию для преобразования")
|
||
print("изображений из стиля Yandex в стиль Google.")
|
||
print("\nОсновные характеристики:")
|
||
print(" • Генератор: U-Net архитектура")
|
||
print(" • Дискриминатор: PatchGAN (43x43 патчей)")
|
||
print(" • Размер входных/выходных изображений: 700x700")
|
||
print(" • Поддержка режимов: vanilla, lsgan")
|
||
print(" • L1 регуляризация для сохранения структуры")
|
||
|
||
except Exception as e:
|
||
print(f"\n❌ Ошибка при тестировании: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc()
|
||
return 1
|
||
|
||
return 0
|
||
|
||
|
||
if __name__ == "__main__":
|
||
exit_code = main()
|
||
sys.exit(exit_code)
|