feat: add models
This commit is contained in:
347
models/GAN/train_example.py
Normal file
347
models/GAN/train_example.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
Пример обучения GAN модели для преобразования Yandex → Google карт.
|
||||
|
||||
Этот скрипт показывает, как использовать GANTrainer для обучения модели.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Добавляем путь к модулям
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from models.GAN.gan import create_image_gan
|
||||
from models.GAN.trainer import GANTrainer
|
||||
|
||||
|
||||
def create_simple_config():
|
||||
"""Создает простую конфигурацию для обучения."""
|
||||
config = {
|
||||
# Параметры оптимизатора
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
# Параметры обучения
|
||||
"batch_size": 4,
|
||||
"epochs": 100,
|
||||
# Параметры GAN
|
||||
"gan_mode": "vanilla", # "vanilla", "lsgan", или "wgangp"
|
||||
"lambda_L1": 100.0, # Вес L1 потерь
|
||||
# Регуляризация
|
||||
"grad_clip": 1.0,
|
||||
# Ранняя остановка
|
||||
"early_stopping_patience": 20,
|
||||
# Выходные данные
|
||||
"output_dir": "runs/gan_training",
|
||||
# Логирование
|
||||
"log_interval": 10, # Логировать каждые N батчей
|
||||
"save_interval": 5, # Сохранять чекпоинт каждые N эпох
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def create_advanced_config():
|
||||
"""Создает расширенную конфигурацию для обучения."""
|
||||
config = {
|
||||
# Параметры оптимизатора
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
# Планировщик learning rate
|
||||
"use_scheduler": True,
|
||||
"scheduler_type": "linear", # "linear", "cosine", или "plateau"
|
||||
"scheduler_start_epoch": 50,
|
||||
"scheduler_end_epoch": 100,
|
||||
# Параметры обучения
|
||||
"batch_size": 8,
|
||||
"epochs": 200,
|
||||
# Параметры GAN
|
||||
"gan_mode": "lsgan", # LSGAN обычно более стабилен
|
||||
"lambda_L1": 100.0,
|
||||
# Аугментация данных
|
||||
"augmentation": {
|
||||
"random_crop": True,
|
||||
"crop_size": 256,
|
||||
"random_flip": True,
|
||||
"color_jitter": True,
|
||||
"brightness": 0.2,
|
||||
"contrast": 0.2,
|
||||
"saturation": 0.2,
|
||||
"hue": 0.1,
|
||||
},
|
||||
# Регуляризация
|
||||
"grad_clip": 1.0,
|
||||
"weight_decay": 1e-4,
|
||||
# Ранняя остановка
|
||||
"early_stopping_patience": 30,
|
||||
"early_stopping_min_delta": 1e-4,
|
||||
# Выходные данные
|
||||
"output_dir": "runs/gan_advanced",
|
||||
# Логирование
|
||||
"log_interval": 20,
|
||||
"save_interval": 10,
|
||||
"save_best_only": True, # Сохранять только лучшую модель
|
||||
# Визуализация
|
||||
"visualize_samples": True,
|
||||
"num_visualize": 4,
|
||||
"visualize_interval": 5, # Визуализировать каждые N эпох
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def print_config_summary(config):
|
||||
"""Печатает сводку конфигурации."""
|
||||
print("=" * 60)
|
||||
print("Конфигурация обучения GAN")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"\nПараметры модели:")
|
||||
print(f" Режим GAN: {config.get('gan_mode', 'vanilla')}")
|
||||
print(f" Вес L1 потерь: {config.get('lambda_L1', 100.0)}")
|
||||
|
||||
print(f"\nПараметры обучения:")
|
||||
print(f" Learning rate: {config.get('learning_rate', 2e-4)}")
|
||||
print(f" Batch size: {config.get('batch_size', 4)}")
|
||||
print(f" Эпох: {config.get('epochs', 100)}")
|
||||
print(f" Beta1: {config.get('beta1', 0.5)}")
|
||||
print(f" Beta2: {config.get('beta2', 0.999)}")
|
||||
|
||||
if config.get("use_scheduler", False):
|
||||
print(f" Планировщик LR: {config.get('scheduler_type', 'linear')}")
|
||||
|
||||
print(f"\nРегуляризация:")
|
||||
print(f" Gradient clipping: {config.get('grad_clip', 1.0)}")
|
||||
if "weight_decay" in config:
|
||||
print(f" Weight decay: {config['weight_decay']}")
|
||||
|
||||
print(f"\nРанняя остановка:")
|
||||
if config.get("early_stopping_patience", 0) > 0:
|
||||
print(f" Patience: {config['early_stopping_patience']} эпох")
|
||||
if "early_stopping_min_delta" in config:
|
||||
print(f" Min delta: {config['early_stopping_min_delta']}")
|
||||
|
||||
print(f"\nВыходные данные:")
|
||||
print(f" Директория: {config.get('output_dir', 'runs/gan')}")
|
||||
print(f" Интервал сохранения: {config.get('save_interval', 5)} эпох")
|
||||
|
||||
print(f"\nЛогирование:")
|
||||
print(f" Интервал логирования: {config.get('log_interval', 10)} батчей")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def setup_training():
|
||||
"""Настраивает обучение."""
|
||||
print("Настройка обучения GAN...")
|
||||
|
||||
# Выбираем конфигурацию
|
||||
use_advanced = False # Измените на True для расширенной конфигурации
|
||||
|
||||
if use_advanced:
|
||||
config = create_advanced_config()
|
||||
else:
|
||||
config = create_simple_config()
|
||||
|
||||
# Печатаем сводку конфигурации
|
||||
print_config_summary(config)
|
||||
|
||||
# Устройство
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"\nИспользуемое устройство: {device}")
|
||||
|
||||
if device.type == "cuda":
|
||||
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
||||
print(
|
||||
f" Память: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
|
||||
)
|
||||
|
||||
# Создаем модель
|
||||
print("\nСоздание модели...")
|
||||
model = create_image_gan(
|
||||
input_channels=3,
|
||||
output_channels=3,
|
||||
gan_mode=config.get("gan_mode", "vanilla"),
|
||||
lambda_L1=config.get("lambda_L1", 100.0),
|
||||
use_cuda=(device.type == "cuda"),
|
||||
)
|
||||
|
||||
# Создаем даталоадеры
|
||||
print("\nСоздание даталоадеров...")
|
||||
# ЗАМЕНИТЕ ЭТО НА ВАШИ РЕАЛЬНЫЕ ДАННЫЕ
|
||||
# Пример:
|
||||
# from your_dataset_module import create_data_loaders
|
||||
# train_loader, val_loader = create_data_loaders(
|
||||
# data_dir="ваш/путь/к/данным",
|
||||
# batch_size=config["batch_size"],
|
||||
# image_size=(256, 256),
|
||||
# augment=config.get("augmentation", None),
|
||||
# )
|
||||
|
||||
# Для примера создаем фиктивные даталоадеры
|
||||
# ВАЖНО: Замените это на реальные данные!
|
||||
print(" ВНИМАНИЕ: Используются фиктивные данные!")
|
||||
print(" Замените на реальные даталоадеры!")
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
class DummyDataset(Dataset):
|
||||
def __init__(self, num_samples=100):
|
||||
self.num_samples = num_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# Фиктивные данные для примера
|
||||
yandex_img = torch.randn(3, 256, 256)
|
||||
google_img = torch.randn(3, 256, 256)
|
||||
return {"yandex_img": yandex_img, "google_img": google_img}
|
||||
|
||||
train_dataset = DummyDataset(num_samples=100)
|
||||
val_dataset = DummyDataset(num_samples=20)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.get("batch_size", 4),
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config.get("batch_size", 4),
|
||||
shuffle=False,
|
||||
num_workers=0,
|
||||
)
|
||||
|
||||
print(f" Размер обучающего набора: {len(train_dataset)}")
|
||||
print(f" Размер валидационного набора: {len(val_dataset)}")
|
||||
print(f" Батчей в эпохе: {len(train_loader)}")
|
||||
|
||||
# Создаем тренер
|
||||
print("\nСоздание тренера...")
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
return trainer, config
|
||||
|
||||
|
||||
def train_model(trainer, config):
|
||||
"""Запускает обучение модели."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Начало обучения")
|
||||
print("=" * 60)
|
||||
|
||||
epochs = config.get("epochs", 100)
|
||||
|
||||
try:
|
||||
trainer.train(num_epochs=epochs)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Обучение завершено успешно!")
|
||||
print("=" * 60)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nОбучение прервано пользователем.")
|
||||
print("Сохранение текущего состояния...")
|
||||
trainer.save_checkpoint(is_best=False)
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n\nОшибка при обучении: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
# Пытаемся сохранить чекпоинт при ошибке
|
||||
try:
|
||||
trainer.save_checkpoint(is_best=False)
|
||||
print("Текущее состояние сохранено.")
|
||||
except:
|
||||
print("Не удалось сохранить состояние.")
|
||||
|
||||
|
||||
def evaluate_model(trainer, test_loader=None):
|
||||
"""Оценивает обученную модель."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Оценка модели")
|
||||
print("=" * 60)
|
||||
|
||||
if test_loader is None:
|
||||
print("Тестовый даталоадер не предоставлен.")
|
||||
print("Используется валидационный даталоадер для оценки.")
|
||||
test_loader = trainer.val_loader
|
||||
|
||||
metrics = trainer.evaluate(test_loader)
|
||||
|
||||
print("\nМетрики оценки:")
|
||||
for key, value in metrics.items():
|
||||
print(f" {key}: {value:.6f}")
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def generate_examples(model, device, num_examples=4):
|
||||
"""Генерирует примеры преобразования."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Генерация примеров")
|
||||
print("=" * 60)
|
||||
|
||||
model.eval()
|
||||
|
||||
# Создаем фиктивные входные данные
|
||||
yandex_input = torch.randn(num_examples, 3, 256, 256).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
google_output = model(yandex_input)
|
||||
|
||||
print(f"Сгенерировано {num_examples} примеров")
|
||||
print(f"Размер входных данных: {yandex_input.shape}")
|
||||
print(f"Размер выходных данных: {google_output.shape}")
|
||||
|
||||
# Сохраняем примеры (в реальном коде сохраняйте как изображения)
|
||||
print("\nПримеры сгенерированы.")
|
||||
print("В реальном коде сохраняйте их как изображения для визуализации.")
|
||||
|
||||
return yandex_input, google_output
|
||||
|
||||
|
||||
def main():
|
||||
"""Основная функция."""
|
||||
print("=" * 60)
|
||||
print("Пример обучения GAN для преобразования Yandex → Google")
|
||||
print("=" * 60)
|
||||
|
||||
# Настройка
|
||||
trainer, config = setup_training()
|
||||
|
||||
# Обучение
|
||||
train_model(trainer, config)
|
||||
|
||||
# Оценка (требует реальных тестовых данных)
|
||||
# evaluate_model(trainer)
|
||||
|
||||
# Генерация примеров
|
||||
# generate_examples(trainer.model, trainer.device)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Скрипт завершен.")
|
||||
print("=" * 60)
|
||||
print("\nСледующие шаги:")
|
||||
print("1. Замените фиктивные даталоадеры на реальные данные")
|
||||
print("2. Настройте параметры в create_simple_config()")
|
||||
print("3. Запустите обучение с реальными данными")
|
||||
print("4. Визуализируйте результаты")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user