feat: add models
This commit is contained in:
342
models/GAN/test_trainer.py
Normal file
342
models/GAN/test_trainer.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Тестовый скрипт для проверки GAN trainer.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
# Добавляем путь к модулям
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from models.GAN.gan import create_image_gan
|
||||
from models.GAN.trainer import GANTrainer
|
||||
|
||||
|
||||
class SimpleDataset(Dataset):
|
||||
"""Простой датасет для тестирования."""
|
||||
|
||||
def __init__(self, num_samples=100, image_size=(256, 256)):
|
||||
self.num_samples = num_samples
|
||||
self.image_size = image_size
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# Создаем случайные изображения для тестирования
|
||||
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||
|
||||
return {"yandex_img": yandex_img, "google_img": google_img}
|
||||
|
||||
|
||||
def test_gan_model():
|
||||
"""Тестирование GAN модели."""
|
||||
print("Тестирование GAN модели...")
|
||||
|
||||
# Создаем модель
|
||||
model = create_image_gan(
|
||||
input_channels=3,
|
||||
output_channels=3,
|
||||
gan_mode="vanilla",
|
||||
lambda_L1=100.0,
|
||||
use_cuda=False, # Используем CPU для тестирования
|
||||
)
|
||||
|
||||
# Тестируем forward pass
|
||||
batch_size = 2
|
||||
image_size = (256, 256)
|
||||
yandex_input = torch.randn(batch_size, 3, *image_size)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(yandex_input)
|
||||
|
||||
print(f"Входной размер: {yandex_input.shape}")
|
||||
print(f"Выходной размер: {output.shape}")
|
||||
print(f"Диапазон выходных значений: [{output.min():.3f}, {output.max():.3f}]")
|
||||
|
||||
# Проверяем, что выход в диапазоне [-1, 1] (из-за Tanh)
|
||||
assert output.min() >= -1.0 and output.max() <= 1.0, "Выход не в диапазоне [-1, 1]"
|
||||
print("✓ Forward pass работает корректно")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def test_generator_step():
|
||||
"""Тестирование шага генератора."""
|
||||
print("\nТестирование шага генератора...")
|
||||
|
||||
model = create_image_gan(use_cuda=False)
|
||||
model.train_mode()
|
||||
|
||||
batch_size = 2
|
||||
yandex_img = torch.randn(batch_size, 3, 256, 256)
|
||||
google_img = torch.randn(batch_size, 3, 256, 256)
|
||||
|
||||
# Тестируем generator_step
|
||||
total_loss, gan_loss, l1_loss = model.generator_step(yandex_img, google_img)
|
||||
|
||||
print(f"Total loss: {total_loss.item():.6f}")
|
||||
print(f"GAN loss: {gan_loss.item():.6f}")
|
||||
print(f"L1 loss: {l1_loss.item():.6f}")
|
||||
|
||||
assert total_loss.item() > 0, "Потери должны быть положительными"
|
||||
print("✓ Шаг генератора работает корректно")
|
||||
|
||||
|
||||
def test_discriminator_step():
|
||||
"""Тестирование шага дискриминатора."""
|
||||
print("\nТестирование шага дискриминатора...")
|
||||
|
||||
model = create_image_gan(use_cuda=False)
|
||||
model.train_mode()
|
||||
|
||||
batch_size = 2
|
||||
yandex_img = torch.randn(batch_size, 3, 256, 256)
|
||||
google_img = torch.randn(batch_size, 3, 256, 256)
|
||||
|
||||
# Генерируем fake изображение
|
||||
with torch.no_grad():
|
||||
fake_google_img = model.generator(yandex_img)
|
||||
|
||||
# Тестируем discriminator_step
|
||||
total_loss, real_loss, fake_loss = model.discriminator_step(
|
||||
yandex_img, google_img, fake_google_img
|
||||
)
|
||||
|
||||
print(f"Total loss: {total_loss.item():.6f}")
|
||||
print(f"Real loss: {real_loss.item():.6f}")
|
||||
print(f"Fake loss: {fake_loss.item():.6f}")
|
||||
|
||||
assert total_loss.item() > 0, "Потери должны быть положительными"
|
||||
print("✓ Шаг дискриминатора работает корректно")
|
||||
|
||||
|
||||
def test_trainer_initialization():
|
||||
"""Тестирование инициализации тренера."""
|
||||
print("\nТестирование инициализации тренера...")
|
||||
|
||||
# Создаем модель
|
||||
model = create_image_gan(use_cuda=False)
|
||||
|
||||
# Создаем даталоадеры
|
||||
train_dataset = SimpleDataset(num_samples=50)
|
||||
val_dataset = SimpleDataset(num_samples=10)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
|
||||
|
||||
# Конфигурация
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"output_dir": "test_runs/gan",
|
||||
"early_stopping_patience": 10,
|
||||
}
|
||||
|
||||
# Создаем тренер
|
||||
device = torch.device("cpu")
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
print(f"Тренер создан успешно")
|
||||
print(f"Оптимизатор генератора: {type(trainer.optimizer_G).__name__}")
|
||||
print(f"Оптимизатор дискриминатора: {type(trainer.optimizer_D).__name__}")
|
||||
print(f"Выходная директория: {trainer.output_dir}")
|
||||
|
||||
assert trainer.output_dir.exists(), "Выходная директория не создана"
|
||||
print("✓ Тренер инициализирован корректно")
|
||||
|
||||
return trainer, train_loader, val_loader
|
||||
|
||||
|
||||
def test_train_epoch():
|
||||
"""Тестирование одной эпохи обучения."""
|
||||
print("\nТестирование одной эпохи обучения...")
|
||||
|
||||
model = create_image_gan(use_cuda=False)
|
||||
train_dataset = SimpleDataset(num_samples=20)
|
||||
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
|
||||
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||||
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"output_dir": "test_runs/gan",
|
||||
}
|
||||
|
||||
device = torch.device("cpu")
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Запускаем одну эпоху обучения
|
||||
avg_g_loss, avg_d_loss = trainer.train_epoch()
|
||||
|
||||
print(f"Средние потери за эпоху:")
|
||||
print(f" Генератор: {avg_g_loss:.6f}")
|
||||
print(f" Дискриминатор: {avg_d_loss:.6f}")
|
||||
|
||||
assert avg_g_loss > 0, "Потери генератора должны быть положительными"
|
||||
assert avg_d_loss > 0, "Потери дискриминатора должны быть положительными"
|
||||
print("✓ Эпоха обучения завершена успешно")
|
||||
|
||||
|
||||
def test_validation():
|
||||
"""Тестирование валидации."""
|
||||
print("\nТестирование валидации...")
|
||||
|
||||
model = create_image_gan(use_cuda=False)
|
||||
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
|
||||
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||||
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"output_dir": "test_runs/gan",
|
||||
}
|
||||
|
||||
device = torch.device("cpu")
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Запускаем валидацию
|
||||
val_g_loss, val_d_loss = trainer.validate()
|
||||
|
||||
print(f"Потери на валидации:")
|
||||
print(f" Генератор: {val_g_loss:.6f}")
|
||||
print(f" Дискриминатор: {val_d_loss:.6f}")
|
||||
|
||||
assert val_g_loss > 0, "Потери генератора должны быть положительными"
|
||||
assert val_d_loss > 0, "Потери дискриминатора должны быть положительными"
|
||||
print("✓ Валидация завершена успешно")
|
||||
|
||||
|
||||
def test_checkpoint_saving():
|
||||
"""Тестирование сохранения чекпоинтов."""
|
||||
print("\nТестирование сохранения чекпоинтов...")
|
||||
|
||||
model = create_image_gan(use_cuda=False)
|
||||
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
|
||||
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||||
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"output_dir": "test_runs/gan_checkpoint",
|
||||
}
|
||||
|
||||
device = torch.device("cpu")
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Сохраняем чекпоинт
|
||||
trainer.save_checkpoint(is_best=True)
|
||||
|
||||
# Проверяем, что файлы созданы
|
||||
checkpoint_files = list(trainer.output_dir.glob("*.pth"))
|
||||
print(f"Создано файлов чекпоинтов: {len(checkpoint_files)}")
|
||||
|
||||
for file in checkpoint_files:
|
||||
print(f" - {file.name}")
|
||||
|
||||
assert len(checkpoint_files) > 0, "Файлы чекпоинтов не созданы"
|
||||
print("✓ Чекпоинты сохранены успешно")
|
||||
|
||||
# Тестируем загрузку чекпоинта
|
||||
checkpoint_path = checkpoint_files[0]
|
||||
print(f"\nТестируем загрузку чекпоинта: {checkpoint_path}")
|
||||
|
||||
# Создаем новую модель и тренер
|
||||
new_model = create_image_gan(use_cuda=False)
|
||||
new_trainer = GANTrainer(
|
||||
model=new_model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Загружаем чекпоинт
|
||||
new_trainer.load_checkpoint(str(checkpoint_path))
|
||||
|
||||
print(f"Загружен чекпоинт эпохи: {new_trainer.current_epoch + 1}")
|
||||
print("✓ Чекпоинт загружен успешно")
|
||||
|
||||
|
||||
def main():
|
||||
"""Основная функция тестирования."""
|
||||
print("=" * 60)
|
||||
print("Начало тестирования GAN trainer")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# Запускаем все тесты
|
||||
test_gan_model()
|
||||
test_generator_step()
|
||||
test_discriminator_step()
|
||||
test_trainer_initialization()
|
||||
test_train_epoch()
|
||||
test_validation()
|
||||
test_checkpoint_saving()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Все тесты пройдены успешно! ✓")
|
||||
print("=" * 60)
|
||||
|
||||
except Exception as e:
|
||||
print(f"\nОшибка при тестировании: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Очищаем тестовые директории
|
||||
import shutil
|
||||
|
||||
test_dirs = ["test_runs/gan", "test_runs/gan_checkpoint"]
|
||||
for dir_path in test_dirs:
|
||||
if Path(dir_path).exists():
|
||||
shutil.rmtree(dir_path)
|
||||
|
||||
# Запускаем тесты
|
||||
exit_code = main()
|
||||
|
||||
# Очищаем после тестов
|
||||
for dir_path in test_dirs:
|
||||
if Path(dir_path).exists():
|
||||
shutil.rmtree(dir_path)
|
||||
|
||||
exit(exit_code)
|
||||
Reference in New Issue
Block a user