343 lines
12 KiB
Python
343 lines
12 KiB
Python
"""
|
||
Тестовый скрипт для проверки GAN trainer.
|
||
"""
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
from torch.utils.data import DataLoader, Dataset
|
||
|
||
# Добавляем путь к модулям
|
||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||
|
||
from models.GAN.gan import create_image_gan
|
||
from models.GAN.trainer import GANTrainer
|
||
|
||
|
||
class SimpleDataset(Dataset):
|
||
"""Простой датасет для тестирования."""
|
||
|
||
def __init__(self, num_samples=100, image_size=(256, 256)):
|
||
self.num_samples = num_samples
|
||
self.image_size = image_size
|
||
|
||
def __len__(self):
|
||
return self.num_samples
|
||
|
||
def __getitem__(self, idx):
|
||
# Создаем случайные изображения для тестирования
|
||
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||
|
||
return {"yandex_img": yandex_img, "google_img": google_img}
|
||
|
||
|
||
def test_gan_model():
|
||
"""Тестирование GAN модели."""
|
||
print("Тестирование GAN модели...")
|
||
|
||
# Создаем модель
|
||
model = create_image_gan(
|
||
input_channels=3,
|
||
output_channels=3,
|
||
gan_mode="vanilla",
|
||
lambda_L1=100.0,
|
||
use_cuda=False, # Используем CPU для тестирования
|
||
)
|
||
|
||
# Тестируем forward pass
|
||
batch_size = 2
|
||
image_size = (256, 256)
|
||
yandex_input = torch.randn(batch_size, 3, *image_size)
|
||
|
||
with torch.no_grad():
|
||
output = model(yandex_input)
|
||
|
||
print(f"Входной размер: {yandex_input.shape}")
|
||
print(f"Выходной размер: {output.shape}")
|
||
print(f"Диапазон выходных значений: [{output.min():.3f}, {output.max():.3f}]")
|
||
|
||
# Проверяем, что выход в диапазоне [-1, 1] (из-за Tanh)
|
||
assert output.min() >= -1.0 and output.max() <= 1.0, "Выход не в диапазоне [-1, 1]"
|
||
print("✓ Forward pass работает корректно")
|
||
|
||
return model
|
||
|
||
|
||
def test_generator_step():
|
||
"""Тестирование шага генератора."""
|
||
print("\nТестирование шага генератора...")
|
||
|
||
model = create_image_gan(use_cuda=False)
|
||
model.train_mode()
|
||
|
||
batch_size = 2
|
||
yandex_img = torch.randn(batch_size, 3, 256, 256)
|
||
google_img = torch.randn(batch_size, 3, 256, 256)
|
||
|
||
# Тестируем generator_step
|
||
total_loss, gan_loss, l1_loss = model.generator_step(yandex_img, google_img)
|
||
|
||
print(f"Total loss: {total_loss.item():.6f}")
|
||
print(f"GAN loss: {gan_loss.item():.6f}")
|
||
print(f"L1 loss: {l1_loss.item():.6f}")
|
||
|
||
assert total_loss.item() > 0, "Потери должны быть положительными"
|
||
print("✓ Шаг генератора работает корректно")
|
||
|
||
|
||
def test_discriminator_step():
|
||
"""Тестирование шага дискриминатора."""
|
||
print("\nТестирование шага дискриминатора...")
|
||
|
||
model = create_image_gan(use_cuda=False)
|
||
model.train_mode()
|
||
|
||
batch_size = 2
|
||
yandex_img = torch.randn(batch_size, 3, 256, 256)
|
||
google_img = torch.randn(batch_size, 3, 256, 256)
|
||
|
||
# Генерируем fake изображение
|
||
with torch.no_grad():
|
||
fake_google_img = model.generator(yandex_img)
|
||
|
||
# Тестируем discriminator_step
|
||
total_loss, real_loss, fake_loss = model.discriminator_step(
|
||
yandex_img, google_img, fake_google_img
|
||
)
|
||
|
||
print(f"Total loss: {total_loss.item():.6f}")
|
||
print(f"Real loss: {real_loss.item():.6f}")
|
||
print(f"Fake loss: {fake_loss.item():.6f}")
|
||
|
||
assert total_loss.item() > 0, "Потери должны быть положительными"
|
||
print("✓ Шаг дискриминатора работает корректно")
|
||
|
||
|
||
def test_trainer_initialization():
|
||
"""Тестирование инициализации тренера."""
|
||
print("\nТестирование инициализации тренера...")
|
||
|
||
# Создаем модель
|
||
model = create_image_gan(use_cuda=False)
|
||
|
||
# Создаем даталоадеры
|
||
train_dataset = SimpleDataset(num_samples=50)
|
||
val_dataset = SimpleDataset(num_samples=10)
|
||
|
||
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
|
||
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
|
||
|
||
# Конфигурация
|
||
config = {
|
||
"learning_rate": 2e-4,
|
||
"beta1": 0.5,
|
||
"beta2": 0.999,
|
||
"output_dir": "test_runs/gan",
|
||
"early_stopping_patience": 10,
|
||
}
|
||
|
||
# Создаем тренер
|
||
device = torch.device("cpu")
|
||
trainer = GANTrainer(
|
||
model=model,
|
||
train_loader=train_loader,
|
||
val_loader=val_loader,
|
||
device=device,
|
||
config=config,
|
||
)
|
||
|
||
print(f"Тренер создан успешно")
|
||
print(f"Оптимизатор генератора: {type(trainer.optimizer_G).__name__}")
|
||
print(f"Оптимизатор дискриминатора: {type(trainer.optimizer_D).__name__}")
|
||
print(f"Выходная директория: {trainer.output_dir}")
|
||
|
||
assert trainer.output_dir.exists(), "Выходная директория не создана"
|
||
print("✓ Тренер инициализирован корректно")
|
||
|
||
return trainer, train_loader, val_loader
|
||
|
||
|
||
def test_train_epoch():
|
||
"""Тестирование одной эпохи обучения."""
|
||
print("\nТестирование одной эпохи обучения...")
|
||
|
||
model = create_image_gan(use_cuda=False)
|
||
train_dataset = SimpleDataset(num_samples=20)
|
||
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
|
||
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||
|
||
config = {
|
||
"learning_rate": 2e-4,
|
||
"beta1": 0.5,
|
||
"beta2": 0.999,
|
||
"output_dir": "test_runs/gan",
|
||
}
|
||
|
||
device = torch.device("cpu")
|
||
trainer = GANTrainer(
|
||
model=model,
|
||
train_loader=train_loader,
|
||
val_loader=val_loader,
|
||
device=device,
|
||
config=config,
|
||
)
|
||
|
||
# Запускаем одну эпоху обучения
|
||
avg_g_loss, avg_d_loss = trainer.train_epoch()
|
||
|
||
print(f"Средние потери за эпоху:")
|
||
print(f" Генератор: {avg_g_loss:.6f}")
|
||
print(f" Дискриминатор: {avg_d_loss:.6f}")
|
||
|
||
assert avg_g_loss > 0, "Потери генератора должны быть положительными"
|
||
assert avg_d_loss > 0, "Потери дискриминатора должны быть положительными"
|
||
print("✓ Эпоха обучения завершена успешно")
|
||
|
||
|
||
def test_validation():
|
||
"""Тестирование валидации."""
|
||
print("\nТестирование валидации...")
|
||
|
||
model = create_image_gan(use_cuda=False)
|
||
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
|
||
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||
|
||
config = {
|
||
"learning_rate": 2e-4,
|
||
"beta1": 0.5,
|
||
"beta2": 0.999,
|
||
"output_dir": "test_runs/gan",
|
||
}
|
||
|
||
device = torch.device("cpu")
|
||
trainer = GANTrainer(
|
||
model=model,
|
||
train_loader=train_loader,
|
||
val_loader=val_loader,
|
||
device=device,
|
||
config=config,
|
||
)
|
||
|
||
# Запускаем валидацию
|
||
val_g_loss, val_d_loss = trainer.validate()
|
||
|
||
print(f"Потери на валидации:")
|
||
print(f" Генератор: {val_g_loss:.6f}")
|
||
print(f" Дискриминатор: {val_d_loss:.6f}")
|
||
|
||
assert val_g_loss > 0, "Потери генератора должны быть положительными"
|
||
assert val_d_loss > 0, "Потери дискриминатора должны быть положительными"
|
||
print("✓ Валидация завершена успешно")
|
||
|
||
|
||
def test_checkpoint_saving():
|
||
"""Тестирование сохранения чекпоинтов."""
|
||
print("\nТестирование сохранения чекпоинтов...")
|
||
|
||
model = create_image_gan(use_cuda=False)
|
||
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
|
||
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||
|
||
config = {
|
||
"learning_rate": 2e-4,
|
||
"beta1": 0.5,
|
||
"beta2": 0.999,
|
||
"output_dir": "test_runs/gan_checkpoint",
|
||
}
|
||
|
||
device = torch.device("cpu")
|
||
trainer = GANTrainer(
|
||
model=model,
|
||
train_loader=train_loader,
|
||
val_loader=val_loader,
|
||
device=device,
|
||
config=config,
|
||
)
|
||
|
||
# Сохраняем чекпоинт
|
||
trainer.save_checkpoint(is_best=True)
|
||
|
||
# Проверяем, что файлы созданы
|
||
checkpoint_files = list(trainer.output_dir.glob("*.pth"))
|
||
print(f"Создано файлов чекпоинтов: {len(checkpoint_files)}")
|
||
|
||
for file in checkpoint_files:
|
||
print(f" - {file.name}")
|
||
|
||
assert len(checkpoint_files) > 0, "Файлы чекпоинтов не созданы"
|
||
print("✓ Чекпоинты сохранены успешно")
|
||
|
||
# Тестируем загрузку чекпоинта
|
||
checkpoint_path = checkpoint_files[0]
|
||
print(f"\nТестируем загрузку чекпоинта: {checkpoint_path}")
|
||
|
||
# Создаем новую модель и тренер
|
||
new_model = create_image_gan(use_cuda=False)
|
||
new_trainer = GANTrainer(
|
||
model=new_model,
|
||
train_loader=train_loader,
|
||
val_loader=val_loader,
|
||
device=device,
|
||
config=config,
|
||
)
|
||
|
||
# Загружаем чекпоинт
|
||
new_trainer.load_checkpoint(str(checkpoint_path))
|
||
|
||
print(f"Загружен чекпоинт эпохи: {new_trainer.current_epoch + 1}")
|
||
print("✓ Чекпоинт загружен успешно")
|
||
|
||
|
||
def main():
|
||
"""Основная функция тестирования."""
|
||
print("=" * 60)
|
||
print("Начало тестирования GAN trainer")
|
||
print("=" * 60)
|
||
|
||
try:
|
||
# Запускаем все тесты
|
||
test_gan_model()
|
||
test_generator_step()
|
||
test_discriminator_step()
|
||
test_trainer_initialization()
|
||
test_train_epoch()
|
||
test_validation()
|
||
test_checkpoint_saving()
|
||
|
||
print("\n" + "=" * 60)
|
||
print("Все тесты пройдены успешно! ✓")
|
||
print("=" * 60)
|
||
|
||
except Exception as e:
|
||
print(f"\nОшибка при тестировании: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc()
|
||
return 1
|
||
|
||
return 0
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# Очищаем тестовые директории
|
||
import shutil
|
||
|
||
test_dirs = ["test_runs/gan", "test_runs/gan_checkpoint"]
|
||
for dir_path in test_dirs:
|
||
if Path(dir_path).exists():
|
||
shutil.rmtree(dir_path)
|
||
|
||
# Запускаем тесты
|
||
exit_code = main()
|
||
|
||
# Очищаем после тестов
|
||
for dir_path in test_dirs:
|
||
if Path(dir_path).exists():
|
||
shutil.rmtree(dir_path)
|
||
|
||
exit(exit_code)
|