348 lines
12 KiB
Python
348 lines
12 KiB
Python
"""
|
||
Пример обучения GAN модели для преобразования Yandex → Google карт.
|
||
|
||
Этот скрипт показывает, как использовать GANTrainer для обучения модели.
|
||
"""
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
from torch.utils.data import DataLoader
|
||
|
||
# Добавляем путь к модулям
|
||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||
|
||
from models.GAN.gan import create_image_gan
|
||
from models.GAN.trainer import GANTrainer
|
||
|
||
|
||
def create_simple_config():
|
||
"""Создает простую конфигурацию для обучения."""
|
||
config = {
|
||
# Параметры оптимизатора
|
||
"learning_rate": 2e-4,
|
||
"beta1": 0.5,
|
||
"beta2": 0.999,
|
||
# Параметры обучения
|
||
"batch_size": 4,
|
||
"epochs": 100,
|
||
# Параметры GAN
|
||
"gan_mode": "vanilla", # "vanilla", "lsgan", или "wgangp"
|
||
"lambda_L1": 100.0, # Вес L1 потерь
|
||
# Регуляризация
|
||
"grad_clip": 1.0,
|
||
# Ранняя остановка
|
||
"early_stopping_patience": 20,
|
||
# Выходные данные
|
||
"output_dir": "runs/gan_training",
|
||
# Логирование
|
||
"log_interval": 10, # Логировать каждые N батчей
|
||
"save_interval": 5, # Сохранять чекпоинт каждые N эпох
|
||
}
|
||
return config
|
||
|
||
|
||
def create_advanced_config():
|
||
"""Создает расширенную конфигурацию для обучения."""
|
||
config = {
|
||
# Параметры оптимизатора
|
||
"learning_rate": 2e-4,
|
||
"beta1": 0.5,
|
||
"beta2": 0.999,
|
||
# Планировщик learning rate
|
||
"use_scheduler": True,
|
||
"scheduler_type": "linear", # "linear", "cosine", или "plateau"
|
||
"scheduler_start_epoch": 50,
|
||
"scheduler_end_epoch": 100,
|
||
# Параметры обучения
|
||
"batch_size": 8,
|
||
"epochs": 200,
|
||
# Параметры GAN
|
||
"gan_mode": "lsgan", # LSGAN обычно более стабилен
|
||
"lambda_L1": 100.0,
|
||
# Аугментация данных
|
||
"augmentation": {
|
||
"random_crop": True,
|
||
"crop_size": 256,
|
||
"random_flip": True,
|
||
"color_jitter": True,
|
||
"brightness": 0.2,
|
||
"contrast": 0.2,
|
||
"saturation": 0.2,
|
||
"hue": 0.1,
|
||
},
|
||
# Регуляризация
|
||
"grad_clip": 1.0,
|
||
"weight_decay": 1e-4,
|
||
# Ранняя остановка
|
||
"early_stopping_patience": 30,
|
||
"early_stopping_min_delta": 1e-4,
|
||
# Выходные данные
|
||
"output_dir": "runs/gan_advanced",
|
||
# Логирование
|
||
"log_interval": 20,
|
||
"save_interval": 10,
|
||
"save_best_only": True, # Сохранять только лучшую модель
|
||
# Визуализация
|
||
"visualize_samples": True,
|
||
"num_visualize": 4,
|
||
"visualize_interval": 5, # Визуализировать каждые N эпох
|
||
}
|
||
return config
|
||
|
||
|
||
def print_config_summary(config):
|
||
"""Печатает сводку конфигурации."""
|
||
print("=" * 60)
|
||
print("Конфигурация обучения GAN")
|
||
print("=" * 60)
|
||
|
||
print(f"\nПараметры модели:")
|
||
print(f" Режим GAN: {config.get('gan_mode', 'vanilla')}")
|
||
print(f" Вес L1 потерь: {config.get('lambda_L1', 100.0)}")
|
||
|
||
print(f"\nПараметры обучения:")
|
||
print(f" Learning rate: {config.get('learning_rate', 2e-4)}")
|
||
print(f" Batch size: {config.get('batch_size', 4)}")
|
||
print(f" Эпох: {config.get('epochs', 100)}")
|
||
print(f" Beta1: {config.get('beta1', 0.5)}")
|
||
print(f" Beta2: {config.get('beta2', 0.999)}")
|
||
|
||
if config.get("use_scheduler", False):
|
||
print(f" Планировщик LR: {config.get('scheduler_type', 'linear')}")
|
||
|
||
print(f"\nРегуляризация:")
|
||
print(f" Gradient clipping: {config.get('grad_clip', 1.0)}")
|
||
if "weight_decay" in config:
|
||
print(f" Weight decay: {config['weight_decay']}")
|
||
|
||
print(f"\nРанняя остановка:")
|
||
if config.get("early_stopping_patience", 0) > 0:
|
||
print(f" Patience: {config['early_stopping_patience']} эпох")
|
||
if "early_stopping_min_delta" in config:
|
||
print(f" Min delta: {config['early_stopping_min_delta']}")
|
||
|
||
print(f"\nВыходные данные:")
|
||
print(f" Директория: {config.get('output_dir', 'runs/gan')}")
|
||
print(f" Интервал сохранения: {config.get('save_interval', 5)} эпох")
|
||
|
||
print(f"\nЛогирование:")
|
||
print(f" Интервал логирования: {config.get('log_interval', 10)} батчей")
|
||
|
||
print("=" * 60)
|
||
|
||
|
||
def setup_training():
|
||
"""Настраивает обучение."""
|
||
print("Настройка обучения GAN...")
|
||
|
||
# Выбираем конфигурацию
|
||
use_advanced = False # Измените на True для расширенной конфигурации
|
||
|
||
if use_advanced:
|
||
config = create_advanced_config()
|
||
else:
|
||
config = create_simple_config()
|
||
|
||
# Печатаем сводку конфигурации
|
||
print_config_summary(config)
|
||
|
||
# Устройство
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
print(f"\nИспользуемое устройство: {device}")
|
||
|
||
if device.type == "cuda":
|
||
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
||
print(
|
||
f" Память: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
|
||
)
|
||
|
||
# Создаем модель
|
||
print("\nСоздание модели...")
|
||
model = create_image_gan(
|
||
input_channels=3,
|
||
output_channels=3,
|
||
gan_mode=config.get("gan_mode", "vanilla"),
|
||
lambda_L1=config.get("lambda_L1", 100.0),
|
||
use_cuda=(device.type == "cuda"),
|
||
)
|
||
|
||
# Создаем даталоадеры
|
||
print("\nСоздание даталоадеров...")
|
||
# ЗАМЕНИТЕ ЭТО НА ВАШИ РЕАЛЬНЫЕ ДАННЫЕ
|
||
# Пример:
|
||
# from your_dataset_module import create_data_loaders
|
||
# train_loader, val_loader = create_data_loaders(
|
||
# data_dir="ваш/путь/к/данным",
|
||
# batch_size=config["batch_size"],
|
||
# image_size=(256, 256),
|
||
# augment=config.get("augmentation", None),
|
||
# )
|
||
|
||
# Для примера создаем фиктивные даталоадеры
|
||
# ВАЖНО: Замените это на реальные данные!
|
||
print(" ВНИМАНИЕ: Используются фиктивные данные!")
|
||
print(" Замените на реальные даталоадеры!")
|
||
|
||
import numpy as np
|
||
from torch.utils.data import Dataset
|
||
|
||
class DummyDataset(Dataset):
|
||
def __init__(self, num_samples=100):
|
||
self.num_samples = num_samples
|
||
|
||
def __len__(self):
|
||
return self.num_samples
|
||
|
||
def __getitem__(self, idx):
|
||
# Фиктивные данные для примера
|
||
yandex_img = torch.randn(3, 256, 256)
|
||
google_img = torch.randn(3, 256, 256)
|
||
return {"yandex_img": yandex_img, "google_img": google_img}
|
||
|
||
train_dataset = DummyDataset(num_samples=100)
|
||
val_dataset = DummyDataset(num_samples=20)
|
||
|
||
train_loader = DataLoader(
|
||
train_dataset,
|
||
batch_size=config.get("batch_size", 4),
|
||
shuffle=True,
|
||
num_workers=0,
|
||
)
|
||
|
||
val_loader = DataLoader(
|
||
val_dataset,
|
||
batch_size=config.get("batch_size", 4),
|
||
shuffle=False,
|
||
num_workers=0,
|
||
)
|
||
|
||
print(f" Размер обучающего набора: {len(train_dataset)}")
|
||
print(f" Размер валидационного набора: {len(val_dataset)}")
|
||
print(f" Батчей в эпохе: {len(train_loader)}")
|
||
|
||
# Создаем тренер
|
||
print("\nСоздание тренера...")
|
||
trainer = GANTrainer(
|
||
model=model,
|
||
train_loader=train_loader,
|
||
val_loader=val_loader,
|
||
device=device,
|
||
config=config,
|
||
)
|
||
|
||
return trainer, config
|
||
|
||
|
||
def train_model(trainer, config):
|
||
"""Запускает обучение модели."""
|
||
print("\n" + "=" * 60)
|
||
print("Начало обучения")
|
||
print("=" * 60)
|
||
|
||
epochs = config.get("epochs", 100)
|
||
|
||
try:
|
||
trainer.train(num_epochs=epochs)
|
||
|
||
print("\n" + "=" * 60)
|
||
print("Обучение завершено успешно!")
|
||
print("=" * 60)
|
||
|
||
except KeyboardInterrupt:
|
||
print("\n\nОбучение прервано пользователем.")
|
||
print("Сохранение текущего состояния...")
|
||
trainer.save_checkpoint(is_best=False)
|
||
|
||
except Exception as e:
|
||
print(f"\n\nОшибка при обучении: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc()
|
||
|
||
# Пытаемся сохранить чекпоинт при ошибке
|
||
try:
|
||
trainer.save_checkpoint(is_best=False)
|
||
print("Текущее состояние сохранено.")
|
||
except:
|
||
print("Не удалось сохранить состояние.")
|
||
|
||
|
||
def evaluate_model(trainer, test_loader=None):
|
||
"""Оценивает обученную модель."""
|
||
print("\n" + "=" * 60)
|
||
print("Оценка модели")
|
||
print("=" * 60)
|
||
|
||
if test_loader is None:
|
||
print("Тестовый даталоадер не предоставлен.")
|
||
print("Используется валидационный даталоадер для оценки.")
|
||
test_loader = trainer.val_loader
|
||
|
||
metrics = trainer.evaluate(test_loader)
|
||
|
||
print("\nМетрики оценки:")
|
||
for key, value in metrics.items():
|
||
print(f" {key}: {value:.6f}")
|
||
|
||
return metrics
|
||
|
||
|
||
def generate_examples(model, device, num_examples=4):
|
||
"""Генерирует примеры преобразования."""
|
||
print("\n" + "=" * 60)
|
||
print("Генерация примеров")
|
||
print("=" * 60)
|
||
|
||
model.eval()
|
||
|
||
# Создаем фиктивные входные данные
|
||
yandex_input = torch.randn(num_examples, 3, 256, 256).to(device)
|
||
|
||
with torch.no_grad():
|
||
google_output = model(yandex_input)
|
||
|
||
print(f"Сгенерировано {num_examples} примеров")
|
||
print(f"Размер входных данных: {yandex_input.shape}")
|
||
print(f"Размер выходных данных: {google_output.shape}")
|
||
|
||
# Сохраняем примеры (в реальном коде сохраняйте как изображения)
|
||
print("\nПримеры сгенерированы.")
|
||
print("В реальном коде сохраняйте их как изображения для визуализации.")
|
||
|
||
return yandex_input, google_output
|
||
|
||
|
||
def main():
|
||
"""Основная функция."""
|
||
print("=" * 60)
|
||
print("Пример обучения GAN для преобразования Yandex → Google")
|
||
print("=" * 60)
|
||
|
||
# Настройка
|
||
trainer, config = setup_training()
|
||
|
||
# Обучение
|
||
train_model(trainer, config)
|
||
|
||
# Оценка (требует реальных тестовых данных)
|
||
# evaluate_model(trainer)
|
||
|
||
# Генерация примеров
|
||
# generate_examples(trainer.model, trainer.device)
|
||
|
||
print("\n" + "=" * 60)
|
||
print("Скрипт завершен.")
|
||
print("=" * 60)
|
||
print("\nСледующие шаги:")
|
||
print("1. Замените фиктивные даталоадеры на реальные данные")
|
||
print("2. Настройте параметры в create_simple_config()")
|
||
print("3. Запустите обучение с реальными данными")
|
||
print("4. Визуализируйте результаты")
|
||
print("=" * 60)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|