Files
autopilot/models/GAN/train_example.py
2026-02-20 16:52:02 +03:00

348 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Пример обучения GAN модели для преобразования Yandex → Google карт.
Этот скрипт показывает, как использовать GANTrainer для обучения модели.
"""
import sys
from pathlib import Path
import torch
from torch.utils.data import DataLoader
# Добавляем путь к модулям
sys.path.append(str(Path(__file__).parent.parent.parent))
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
def create_simple_config():
"""Создает простую конфигурацию для обучения."""
config = {
# Параметры оптимизатора
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
# Параметры обучения
"batch_size": 4,
"epochs": 100,
# Параметры GAN
"gan_mode": "vanilla", # "vanilla", "lsgan", или "wgangp"
"lambda_L1": 100.0, # Вес L1 потерь
# Регуляризация
"grad_clip": 1.0,
# Ранняя остановка
"early_stopping_patience": 20,
# Выходные данные
"output_dir": "runs/gan_training",
# Логирование
"log_interval": 10, # Логировать каждые N батчей
"save_interval": 5, # Сохранять чекпоинт каждые N эпох
}
return config
def create_advanced_config():
"""Создает расширенную конфигурацию для обучения."""
config = {
# Параметры оптимизатора
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
# Планировщик learning rate
"use_scheduler": True,
"scheduler_type": "linear", # "linear", "cosine", или "plateau"
"scheduler_start_epoch": 50,
"scheduler_end_epoch": 100,
# Параметры обучения
"batch_size": 8,
"epochs": 200,
# Параметры GAN
"gan_mode": "lsgan", # LSGAN обычно более стабилен
"lambda_L1": 100.0,
# Аугментация данных
"augmentation": {
"random_crop": True,
"crop_size": 256,
"random_flip": True,
"color_jitter": True,
"brightness": 0.2,
"contrast": 0.2,
"saturation": 0.2,
"hue": 0.1,
},
# Регуляризация
"grad_clip": 1.0,
"weight_decay": 1e-4,
# Ранняя остановка
"early_stopping_patience": 30,
"early_stopping_min_delta": 1e-4,
# Выходные данные
"output_dir": "runs/gan_advanced",
# Логирование
"log_interval": 20,
"save_interval": 10,
"save_best_only": True, # Сохранять только лучшую модель
# Визуализация
"visualize_samples": True,
"num_visualize": 4,
"visualize_interval": 5, # Визуализировать каждые N эпох
}
return config
def print_config_summary(config):
"""Печатает сводку конфигурации."""
print("=" * 60)
print("Конфигурация обучения GAN")
print("=" * 60)
print(f"\nПараметры модели:")
print(f" Режим GAN: {config.get('gan_mode', 'vanilla')}")
print(f" Вес L1 потерь: {config.get('lambda_L1', 100.0)}")
print(f"\nПараметры обучения:")
print(f" Learning rate: {config.get('learning_rate', 2e-4)}")
print(f" Batch size: {config.get('batch_size', 4)}")
print(f" Эпох: {config.get('epochs', 100)}")
print(f" Beta1: {config.get('beta1', 0.5)}")
print(f" Beta2: {config.get('beta2', 0.999)}")
if config.get("use_scheduler", False):
print(f" Планировщик LR: {config.get('scheduler_type', 'linear')}")
print(f"\nРегуляризация:")
print(f" Gradient clipping: {config.get('grad_clip', 1.0)}")
if "weight_decay" in config:
print(f" Weight decay: {config['weight_decay']}")
print(f"\nРанняя остановка:")
if config.get("early_stopping_patience", 0) > 0:
print(f" Patience: {config['early_stopping_patience']} эпох")
if "early_stopping_min_delta" in config:
print(f" Min delta: {config['early_stopping_min_delta']}")
print(f"\nВыходные данные:")
print(f" Директория: {config.get('output_dir', 'runs/gan')}")
print(f" Интервал сохранения: {config.get('save_interval', 5)} эпох")
print(f"\nЛогирование:")
print(f" Интервал логирования: {config.get('log_interval', 10)} батчей")
print("=" * 60)
def setup_training():
"""Настраивает обучение."""
print("Настройка обучения GAN...")
# Выбираем конфигурацию
use_advanced = False # Измените на True для расширенной конфигурации
if use_advanced:
config = create_advanced_config()
else:
config = create_simple_config()
# Печатаем сводку конфигурации
print_config_summary(config)
# Устройство
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nИспользуемое устройство: {device}")
if device.type == "cuda":
print(f" GPU: {torch.cuda.get_device_name(0)}")
print(
f" Память: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
)
# Создаем модель
print("\nСоздание модели...")
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode=config.get("gan_mode", "vanilla"),
lambda_L1=config.get("lambda_L1", 100.0),
use_cuda=(device.type == "cuda"),
)
# Создаем даталоадеры
print("\nСоздание даталоадеров...")
# ЗАМЕНИТЕ ЭТО НА ВАШИ РЕАЛЬНЫЕ ДАННЫЕ
# Пример:
# from your_dataset_module import create_data_loaders
# train_loader, val_loader = create_data_loaders(
# data_dir="ваш/путь/к/данным",
# batch_size=config["batch_size"],
# image_size=(256, 256),
# augment=config.get("augmentation", None),
# )
# Для примера создаем фиктивные даталоадеры
# ВАЖНО: Замените это на реальные данные!
print(" ВНИМАНИЕ: Используются фиктивные данные!")
print(" Замените на реальные даталоадеры!")
import numpy as np
from torch.utils.data import Dataset
class DummyDataset(Dataset):
def __init__(self, num_samples=100):
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# Фиктивные данные для примера
yandex_img = torch.randn(3, 256, 256)
google_img = torch.randn(3, 256, 256)
return {"yandex_img": yandex_img, "google_img": google_img}
train_dataset = DummyDataset(num_samples=100)
val_dataset = DummyDataset(num_samples=20)
train_loader = DataLoader(
train_dataset,
batch_size=config.get("batch_size", 4),
shuffle=True,
num_workers=0,
)
val_loader = DataLoader(
val_dataset,
batch_size=config.get("batch_size", 4),
shuffle=False,
num_workers=0,
)
print(f" Размер обучающего набора: {len(train_dataset)}")
print(f" Размер валидационного набора: {len(val_dataset)}")
print(f" Батчей в эпохе: {len(train_loader)}")
# Создаем тренер
print("\nСоздание тренера...")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
return trainer, config
def train_model(trainer, config):
"""Запускает обучение модели."""
print("\n" + "=" * 60)
print("Начало обучения")
print("=" * 60)
epochs = config.get("epochs", 100)
try:
trainer.train(num_epochs=epochs)
print("\n" + "=" * 60)
print("Обучение завершено успешно!")
print("=" * 60)
except KeyboardInterrupt:
print("\n\nОбучение прервано пользователем.")
print("Сохранение текущего состояния...")
trainer.save_checkpoint(is_best=False)
except Exception as e:
print(f"\n\nОшибка при обучении: {e}")
import traceback
traceback.print_exc()
# Пытаемся сохранить чекпоинт при ошибке
try:
trainer.save_checkpoint(is_best=False)
print("Текущее состояние сохранено.")
except:
print("Не удалось сохранить состояние.")
def evaluate_model(trainer, test_loader=None):
"""Оценивает обученную модель."""
print("\n" + "=" * 60)
print("Оценка модели")
print("=" * 60)
if test_loader is None:
print("Тестовый даталоадер не предоставлен.")
print("Используется валидационный даталоадер для оценки.")
test_loader = trainer.val_loader
metrics = trainer.evaluate(test_loader)
print("\nМетрики оценки:")
for key, value in metrics.items():
print(f" {key}: {value:.6f}")
return metrics
def generate_examples(model, device, num_examples=4):
"""Генерирует примеры преобразования."""
print("\n" + "=" * 60)
print("Генерация примеров")
print("=" * 60)
model.eval()
# Создаем фиктивные входные данные
yandex_input = torch.randn(num_examples, 3, 256, 256).to(device)
with torch.no_grad():
google_output = model(yandex_input)
print(f"Сгенерировано {num_examples} примеров")
print(f"Размер входных данных: {yandex_input.shape}")
print(f"Размер выходных данных: {google_output.shape}")
# Сохраняем примеры (в реальном коде сохраняйте как изображения)
print("\nПримеры сгенерированы.")
print("В реальном коде сохраняйте их как изображения для визуализации.")
return yandex_input, google_output
def main():
"""Основная функция."""
print("=" * 60)
print("Пример обучения GAN для преобразования Yandex → Google")
print("=" * 60)
# Настройка
trainer, config = setup_training()
# Обучение
train_model(trainer, config)
# Оценка (требует реальных тестовых данных)
# evaluate_model(trainer)
# Генерация примеров
# generate_examples(trainer.model, trainer.device)
print("\n" + "=" * 60)
print("Скрипт завершен.")
print("=" * 60)
print("\nСледующие шаги:")
print("1. Замените фиктивные даталоадеры на реальные данные")
print("2. Настройте параметры в create_simple_config()")
print("3. Запустите обучение с реальными данными")
print("4. Визуализируйте результаты")
print("=" * 60)
if __name__ == "__main__":
main()