Files
autopilot/models/GAN/test_trainer.py
2026-02-20 16:52:02 +03:00

343 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Тестовый скрипт для проверки GAN trainer.
"""
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
# Добавляем путь к модулям
sys.path.append(str(Path(__file__).parent.parent.parent))
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
class SimpleDataset(Dataset):
"""Простой датасет для тестирования."""
def __init__(self, num_samples=100, image_size=(256, 256)):
self.num_samples = num_samples
self.image_size = image_size
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# Создаем случайные изображения для тестирования
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
return {"yandex_img": yandex_img, "google_img": google_img}
def test_gan_model():
"""Тестирование GAN модели."""
print("Тестирование GAN модели...")
# Создаем модель
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode="vanilla",
lambda_L1=100.0,
use_cuda=False, # Используем CPU для тестирования
)
# Тестируем forward pass
batch_size = 2
image_size = (256, 256)
yandex_input = torch.randn(batch_size, 3, *image_size)
with torch.no_grad():
output = model(yandex_input)
print(f"Входной размер: {yandex_input.shape}")
print(f"Выходной размер: {output.shape}")
print(f"Диапазон выходных значений: [{output.min():.3f}, {output.max():.3f}]")
# Проверяем, что выход в диапазоне [-1, 1] (из-за Tanh)
assert output.min() >= -1.0 and output.max() <= 1.0, "Выход не в диапазоне [-1, 1]"
print("✓ Forward pass работает корректно")
return model
def test_generator_step():
"""Тестирование шага генератора."""
print("\nТестирование шага генератора...")
model = create_image_gan(use_cuda=False)
model.train_mode()
batch_size = 2
yandex_img = torch.randn(batch_size, 3, 256, 256)
google_img = torch.randn(batch_size, 3, 256, 256)
# Тестируем generator_step
total_loss, gan_loss, l1_loss = model.generator_step(yandex_img, google_img)
print(f"Total loss: {total_loss.item():.6f}")
print(f"GAN loss: {gan_loss.item():.6f}")
print(f"L1 loss: {l1_loss.item():.6f}")
assert total_loss.item() > 0, "Потери должны быть положительными"
print("✓ Шаг генератора работает корректно")
def test_discriminator_step():
"""Тестирование шага дискриминатора."""
print("\nТестирование шага дискриминатора...")
model = create_image_gan(use_cuda=False)
model.train_mode()
batch_size = 2
yandex_img = torch.randn(batch_size, 3, 256, 256)
google_img = torch.randn(batch_size, 3, 256, 256)
# Генерируем fake изображение
with torch.no_grad():
fake_google_img = model.generator(yandex_img)
# Тестируем discriminator_step
total_loss, real_loss, fake_loss = model.discriminator_step(
yandex_img, google_img, fake_google_img
)
print(f"Total loss: {total_loss.item():.6f}")
print(f"Real loss: {real_loss.item():.6f}")
print(f"Fake loss: {fake_loss.item():.6f}")
assert total_loss.item() > 0, "Потери должны быть положительными"
print("✓ Шаг дискриминатора работает корректно")
def test_trainer_initialization():
"""Тестирование инициализации тренера."""
print("\nТестирование инициализации тренера...")
# Создаем модель
model = create_image_gan(use_cuda=False)
# Создаем даталоадеры
train_dataset = SimpleDataset(num_samples=50)
val_dataset = SimpleDataset(num_samples=10)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
# Конфигурация
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"output_dir": "test_runs/gan",
"early_stopping_patience": 10,
}
# Создаем тренер
device = torch.device("cpu")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
print(f"Тренер создан успешно")
print(f"Оптимизатор генератора: {type(trainer.optimizer_G).__name__}")
print(f"Оптимизатор дискриминатора: {type(trainer.optimizer_D).__name__}")
print(f"Выходная директория: {trainer.output_dir}")
assert trainer.output_dir.exists(), "Выходная директория не создана"
print("✓ Тренер инициализирован корректно")
return trainer, train_loader, val_loader
def test_train_epoch():
"""Тестирование одной эпохи обучения."""
print("\nТестирование одной эпохи обучения...")
model = create_image_gan(use_cuda=False)
train_dataset = SimpleDataset(num_samples=20)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"output_dir": "test_runs/gan",
}
device = torch.device("cpu")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Запускаем одну эпоху обучения
avg_g_loss, avg_d_loss = trainer.train_epoch()
print(f"Средние потери за эпоху:")
print(f" Генератор: {avg_g_loss:.6f}")
print(f" Дискриминатор: {avg_d_loss:.6f}")
assert avg_g_loss > 0, "Потери генератора должны быть положительными"
assert avg_d_loss > 0, "Потери дискриминатора должны быть положительными"
print("✓ Эпоха обучения завершена успешно")
def test_validation():
"""Тестирование валидации."""
print("\nТестирование валидации...")
model = create_image_gan(use_cuda=False)
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"output_dir": "test_runs/gan",
}
device = torch.device("cpu")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Запускаем валидацию
val_g_loss, val_d_loss = trainer.validate()
print(f"Потери на валидации:")
print(f" Генератор: {val_g_loss:.6f}")
print(f" Дискриминатор: {val_d_loss:.6f}")
assert val_g_loss > 0, "Потери генератора должны быть положительными"
assert val_d_loss > 0, "Потери дискриминатора должны быть положительными"
print("✓ Валидация завершена успешно")
def test_checkpoint_saving():
"""Тестирование сохранения чекпоинтов."""
print("\nТестирование сохранения чекпоинтов...")
model = create_image_gan(use_cuda=False)
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"output_dir": "test_runs/gan_checkpoint",
}
device = torch.device("cpu")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Сохраняем чекпоинт
trainer.save_checkpoint(is_best=True)
# Проверяем, что файлы созданы
checkpoint_files = list(trainer.output_dir.glob("*.pth"))
print(f"Создано файлов чекпоинтов: {len(checkpoint_files)}")
for file in checkpoint_files:
print(f" - {file.name}")
assert len(checkpoint_files) > 0, "Файлы чекпоинтов не созданы"
print("✓ Чекпоинты сохранены успешно")
# Тестируем загрузку чекпоинта
checkpoint_path = checkpoint_files[0]
print(f"\nТестируем загрузку чекпоинта: {checkpoint_path}")
# Создаем новую модель и тренер
new_model = create_image_gan(use_cuda=False)
new_trainer = GANTrainer(
model=new_model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Загружаем чекпоинт
new_trainer.load_checkpoint(str(checkpoint_path))
print(f"Загружен чекпоинт эпохи: {new_trainer.current_epoch + 1}")
print("✓ Чекпоинт загружен успешно")
def main():
"""Основная функция тестирования."""
print("=" * 60)
print("Начало тестирования GAN trainer")
print("=" * 60)
try:
# Запускаем все тесты
test_gan_model()
test_generator_step()
test_discriminator_step()
test_trainer_initialization()
test_train_epoch()
test_validation()
test_checkpoint_saving()
print("\n" + "=" * 60)
print("Все тесты пройдены успешно! ✓")
print("=" * 60)
except Exception as e:
print(f"\nОшибка при тестировании: {e}")
import traceback
traceback.print_exc()
return 1
return 0
if __name__ == "__main__":
# Очищаем тестовые директории
import shutil
test_dirs = ["test_runs/gan", "test_runs/gan_checkpoint"]
for dir_path in test_dirs:
if Path(dir_path).exists():
shutil.rmtree(dir_path)
# Запускаем тесты
exit_code = main()
# Очищаем после тестов
for dir_path in test_dirs:
if Path(dir_path).exists():
shutil.rmtree(dir_path)
exit(exit_code)