feat: add models

This commit is contained in:
2026-02-20 16:52:02 +03:00
parent 6040f3b253
commit 0cc210968f
11 changed files with 4488 additions and 48 deletions

342
models/GAN/test_trainer.py Normal file
View File

@@ -0,0 +1,342 @@
"""
Тестовый скрипт для проверки GAN trainer.
"""
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
# Добавляем путь к модулям
sys.path.append(str(Path(__file__).parent.parent.parent))
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
class SimpleDataset(Dataset):
"""Простой датасет для тестирования."""
def __init__(self, num_samples=100, image_size=(256, 256)):
self.num_samples = num_samples
self.image_size = image_size
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# Создаем случайные изображения для тестирования
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
return {"yandex_img": yandex_img, "google_img": google_img}
def test_gan_model():
"""Тестирование GAN модели."""
print("Тестирование GAN модели...")
# Создаем модель
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode="vanilla",
lambda_L1=100.0,
use_cuda=False, # Используем CPU для тестирования
)
# Тестируем forward pass
batch_size = 2
image_size = (256, 256)
yandex_input = torch.randn(batch_size, 3, *image_size)
with torch.no_grad():
output = model(yandex_input)
print(f"Входной размер: {yandex_input.shape}")
print(f"Выходной размер: {output.shape}")
print(f"Диапазон выходных значений: [{output.min():.3f}, {output.max():.3f}]")
# Проверяем, что выход в диапазоне [-1, 1] (из-за Tanh)
assert output.min() >= -1.0 and output.max() <= 1.0, "Выход не в диапазоне [-1, 1]"
print("✓ Forward pass работает корректно")
return model
def test_generator_step():
"""Тестирование шага генератора."""
print("\nТестирование шага генератора...")
model = create_image_gan(use_cuda=False)
model.train_mode()
batch_size = 2
yandex_img = torch.randn(batch_size, 3, 256, 256)
google_img = torch.randn(batch_size, 3, 256, 256)
# Тестируем generator_step
total_loss, gan_loss, l1_loss = model.generator_step(yandex_img, google_img)
print(f"Total loss: {total_loss.item():.6f}")
print(f"GAN loss: {gan_loss.item():.6f}")
print(f"L1 loss: {l1_loss.item():.6f}")
assert total_loss.item() > 0, "Потери должны быть положительными"
print("✓ Шаг генератора работает корректно")
def test_discriminator_step():
"""Тестирование шага дискриминатора."""
print("\nТестирование шага дискриминатора...")
model = create_image_gan(use_cuda=False)
model.train_mode()
batch_size = 2
yandex_img = torch.randn(batch_size, 3, 256, 256)
google_img = torch.randn(batch_size, 3, 256, 256)
# Генерируем fake изображение
with torch.no_grad():
fake_google_img = model.generator(yandex_img)
# Тестируем discriminator_step
total_loss, real_loss, fake_loss = model.discriminator_step(
yandex_img, google_img, fake_google_img
)
print(f"Total loss: {total_loss.item():.6f}")
print(f"Real loss: {real_loss.item():.6f}")
print(f"Fake loss: {fake_loss.item():.6f}")
assert total_loss.item() > 0, "Потери должны быть положительными"
print("✓ Шаг дискриминатора работает корректно")
def test_trainer_initialization():
"""Тестирование инициализации тренера."""
print("\nТестирование инициализации тренера...")
# Создаем модель
model = create_image_gan(use_cuda=False)
# Создаем даталоадеры
train_dataset = SimpleDataset(num_samples=50)
val_dataset = SimpleDataset(num_samples=10)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
# Конфигурация
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"output_dir": "test_runs/gan",
"early_stopping_patience": 10,
}
# Создаем тренер
device = torch.device("cpu")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
print(f"Тренер создан успешно")
print(f"Оптимизатор генератора: {type(trainer.optimizer_G).__name__}")
print(f"Оптимизатор дискриминатора: {type(trainer.optimizer_D).__name__}")
print(f"Выходная директория: {trainer.output_dir}")
assert trainer.output_dir.exists(), "Выходная директория не создана"
print("✓ Тренер инициализирован корректно")
return trainer, train_loader, val_loader
def test_train_epoch():
"""Тестирование одной эпохи обучения."""
print("\nТестирование одной эпохи обучения...")
model = create_image_gan(use_cuda=False)
train_dataset = SimpleDataset(num_samples=20)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"output_dir": "test_runs/gan",
}
device = torch.device("cpu")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Запускаем одну эпоху обучения
avg_g_loss, avg_d_loss = trainer.train_epoch()
print(f"Средние потери за эпоху:")
print(f" Генератор: {avg_g_loss:.6f}")
print(f" Дискриминатор: {avg_d_loss:.6f}")
assert avg_g_loss > 0, "Потери генератора должны быть положительными"
assert avg_d_loss > 0, "Потери дискриминатора должны быть положительными"
print("✓ Эпоха обучения завершена успешно")
def test_validation():
"""Тестирование валидации."""
print("\nТестирование валидации...")
model = create_image_gan(use_cuda=False)
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"output_dir": "test_runs/gan",
}
device = torch.device("cpu")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Запускаем валидацию
val_g_loss, val_d_loss = trainer.validate()
print(f"Потери на валидации:")
print(f" Генератор: {val_g_loss:.6f}")
print(f" Дискриминатор: {val_d_loss:.6f}")
assert val_g_loss > 0, "Потери генератора должны быть положительными"
assert val_d_loss > 0, "Потери дискриминатора должны быть положительными"
print("✓ Валидация завершена успешно")
def test_checkpoint_saving():
"""Тестирование сохранения чекпоинтов."""
print("\nТестирование сохранения чекпоинтов...")
model = create_image_gan(use_cuda=False)
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"output_dir": "test_runs/gan_checkpoint",
}
device = torch.device("cpu")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Сохраняем чекпоинт
trainer.save_checkpoint(is_best=True)
# Проверяем, что файлы созданы
checkpoint_files = list(trainer.output_dir.glob("*.pth"))
print(f"Создано файлов чекпоинтов: {len(checkpoint_files)}")
for file in checkpoint_files:
print(f" - {file.name}")
assert len(checkpoint_files) > 0, "Файлы чекпоинтов не созданы"
print("✓ Чекпоинты сохранены успешно")
# Тестируем загрузку чекпоинта
checkpoint_path = checkpoint_files[0]
print(f"\nТестируем загрузку чекпоинта: {checkpoint_path}")
# Создаем новую модель и тренер
new_model = create_image_gan(use_cuda=False)
new_trainer = GANTrainer(
model=new_model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Загружаем чекпоинт
new_trainer.load_checkpoint(str(checkpoint_path))
print(f"Загружен чекпоинт эпохи: {new_trainer.current_epoch + 1}")
print("✓ Чекпоинт загружен успешно")
def main():
"""Основная функция тестирования."""
print("=" * 60)
print("Начало тестирования GAN trainer")
print("=" * 60)
try:
# Запускаем все тесты
test_gan_model()
test_generator_step()
test_discriminator_step()
test_trainer_initialization()
test_train_epoch()
test_validation()
test_checkpoint_saving()
print("\n" + "=" * 60)
print("Все тесты пройдены успешно! ✓")
print("=" * 60)
except Exception as e:
print(f"\nОшибка при тестировании: {e}")
import traceback
traceback.print_exc()
return 1
return 0
if __name__ == "__main__":
# Очищаем тестовые директории
import shutil
test_dirs = ["test_runs/gan", "test_runs/gan_checkpoint"]
for dir_path in test_dirs:
if Path(dir_path).exists():
shutil.rmtree(dir_path)
# Запускаем тесты
exit_code = main()
# Очищаем после тестов
for dir_path in test_dirs:
if Path(dir_path).exists():
shutil.rmtree(dir_path)
exit(exit_code)