feat: add models

This commit is contained in:
2026-02-20 16:52:02 +03:00
parent 6040f3b253
commit 0cc210968f
11 changed files with 4488 additions and 48 deletions

347
models/GAN/train_example.py Normal file
View File

@@ -0,0 +1,347 @@
"""
Пример обучения GAN модели для преобразования Yandex → Google карт.
Этот скрипт показывает, как использовать GANTrainer для обучения модели.
"""
import sys
from pathlib import Path
import torch
from torch.utils.data import DataLoader
# Добавляем путь к модулям
sys.path.append(str(Path(__file__).parent.parent.parent))
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
def create_simple_config():
"""Создает простую конфигурацию для обучения."""
config = {
# Параметры оптимизатора
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
# Параметры обучения
"batch_size": 4,
"epochs": 100,
# Параметры GAN
"gan_mode": "vanilla", # "vanilla", "lsgan", или "wgangp"
"lambda_L1": 100.0, # Вес L1 потерь
# Регуляризация
"grad_clip": 1.0,
# Ранняя остановка
"early_stopping_patience": 20,
# Выходные данные
"output_dir": "runs/gan_training",
# Логирование
"log_interval": 10, # Логировать каждые N батчей
"save_interval": 5, # Сохранять чекпоинт каждые N эпох
}
return config
def create_advanced_config():
"""Создает расширенную конфигурацию для обучения."""
config = {
# Параметры оптимизатора
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
# Планировщик learning rate
"use_scheduler": True,
"scheduler_type": "linear", # "linear", "cosine", или "plateau"
"scheduler_start_epoch": 50,
"scheduler_end_epoch": 100,
# Параметры обучения
"batch_size": 8,
"epochs": 200,
# Параметры GAN
"gan_mode": "lsgan", # LSGAN обычно более стабилен
"lambda_L1": 100.0,
# Аугментация данных
"augmentation": {
"random_crop": True,
"crop_size": 256,
"random_flip": True,
"color_jitter": True,
"brightness": 0.2,
"contrast": 0.2,
"saturation": 0.2,
"hue": 0.1,
},
# Регуляризация
"grad_clip": 1.0,
"weight_decay": 1e-4,
# Ранняя остановка
"early_stopping_patience": 30,
"early_stopping_min_delta": 1e-4,
# Выходные данные
"output_dir": "runs/gan_advanced",
# Логирование
"log_interval": 20,
"save_interval": 10,
"save_best_only": True, # Сохранять только лучшую модель
# Визуализация
"visualize_samples": True,
"num_visualize": 4,
"visualize_interval": 5, # Визуализировать каждые N эпох
}
return config
def print_config_summary(config):
"""Печатает сводку конфигурации."""
print("=" * 60)
print("Конфигурация обучения GAN")
print("=" * 60)
print(f"\nПараметры модели:")
print(f" Режим GAN: {config.get('gan_mode', 'vanilla')}")
print(f" Вес L1 потерь: {config.get('lambda_L1', 100.0)}")
print(f"\nПараметры обучения:")
print(f" Learning rate: {config.get('learning_rate', 2e-4)}")
print(f" Batch size: {config.get('batch_size', 4)}")
print(f" Эпох: {config.get('epochs', 100)}")
print(f" Beta1: {config.get('beta1', 0.5)}")
print(f" Beta2: {config.get('beta2', 0.999)}")
if config.get("use_scheduler", False):
print(f" Планировщик LR: {config.get('scheduler_type', 'linear')}")
print(f"\nРегуляризация:")
print(f" Gradient clipping: {config.get('grad_clip', 1.0)}")
if "weight_decay" in config:
print(f" Weight decay: {config['weight_decay']}")
print(f"\nРанняя остановка:")
if config.get("early_stopping_patience", 0) > 0:
print(f" Patience: {config['early_stopping_patience']} эпох")
if "early_stopping_min_delta" in config:
print(f" Min delta: {config['early_stopping_min_delta']}")
print(f"\nВыходные данные:")
print(f" Директория: {config.get('output_dir', 'runs/gan')}")
print(f" Интервал сохранения: {config.get('save_interval', 5)} эпох")
print(f"\nЛогирование:")
print(f" Интервал логирования: {config.get('log_interval', 10)} батчей")
print("=" * 60)
def setup_training():
"""Настраивает обучение."""
print("Настройка обучения GAN...")
# Выбираем конфигурацию
use_advanced = False # Измените на True для расширенной конфигурации
if use_advanced:
config = create_advanced_config()
else:
config = create_simple_config()
# Печатаем сводку конфигурации
print_config_summary(config)
# Устройство
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nИспользуемое устройство: {device}")
if device.type == "cuda":
print(f" GPU: {torch.cuda.get_device_name(0)}")
print(
f" Память: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
)
# Создаем модель
print("\nСоздание модели...")
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode=config.get("gan_mode", "vanilla"),
lambda_L1=config.get("lambda_L1", 100.0),
use_cuda=(device.type == "cuda"),
)
# Создаем даталоадеры
print("\nСоздание даталоадеров...")
# ЗАМЕНИТЕ ЭТО НА ВАШИ РЕАЛЬНЫЕ ДАННЫЕ
# Пример:
# from your_dataset_module import create_data_loaders
# train_loader, val_loader = create_data_loaders(
# data_dir="ваш/путь/к/данным",
# batch_size=config["batch_size"],
# image_size=(256, 256),
# augment=config.get("augmentation", None),
# )
# Для примера создаем фиктивные даталоадеры
# ВАЖНО: Замените это на реальные данные!
print(" ВНИМАНИЕ: Используются фиктивные данные!")
print(" Замените на реальные даталоадеры!")
import numpy as np
from torch.utils.data import Dataset
class DummyDataset(Dataset):
def __init__(self, num_samples=100):
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# Фиктивные данные для примера
yandex_img = torch.randn(3, 256, 256)
google_img = torch.randn(3, 256, 256)
return {"yandex_img": yandex_img, "google_img": google_img}
train_dataset = DummyDataset(num_samples=100)
val_dataset = DummyDataset(num_samples=20)
train_loader = DataLoader(
train_dataset,
batch_size=config.get("batch_size", 4),
shuffle=True,
num_workers=0,
)
val_loader = DataLoader(
val_dataset,
batch_size=config.get("batch_size", 4),
shuffle=False,
num_workers=0,
)
print(f" Размер обучающего набора: {len(train_dataset)}")
print(f" Размер валидационного набора: {len(val_dataset)}")
print(f" Батчей в эпохе: {len(train_loader)}")
# Создаем тренер
print("\nСоздание тренера...")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
return trainer, config
def train_model(trainer, config):
"""Запускает обучение модели."""
print("\n" + "=" * 60)
print("Начало обучения")
print("=" * 60)
epochs = config.get("epochs", 100)
try:
trainer.train(num_epochs=epochs)
print("\n" + "=" * 60)
print("Обучение завершено успешно!")
print("=" * 60)
except KeyboardInterrupt:
print("\n\nОбучение прервано пользователем.")
print("Сохранение текущего состояния...")
trainer.save_checkpoint(is_best=False)
except Exception as e:
print(f"\n\nОшибка при обучении: {e}")
import traceback
traceback.print_exc()
# Пытаемся сохранить чекпоинт при ошибке
try:
trainer.save_checkpoint(is_best=False)
print("Текущее состояние сохранено.")
except:
print("Не удалось сохранить состояние.")
def evaluate_model(trainer, test_loader=None):
"""Оценивает обученную модель."""
print("\n" + "=" * 60)
print("Оценка модели")
print("=" * 60)
if test_loader is None:
print("Тестовый даталоадер не предоставлен.")
print("Используется валидационный даталоадер для оценки.")
test_loader = trainer.val_loader
metrics = trainer.evaluate(test_loader)
print("\nМетрики оценки:")
for key, value in metrics.items():
print(f" {key}: {value:.6f}")
return metrics
def generate_examples(model, device, num_examples=4):
"""Генерирует примеры преобразования."""
print("\n" + "=" * 60)
print("Генерация примеров")
print("=" * 60)
model.eval()
# Создаем фиктивные входные данные
yandex_input = torch.randn(num_examples, 3, 256, 256).to(device)
with torch.no_grad():
google_output = model(yandex_input)
print(f"Сгенерировано {num_examples} примеров")
print(f"Размер входных данных: {yandex_input.shape}")
print(f"Размер выходных данных: {google_output.shape}")
# Сохраняем примеры (в реальном коде сохраняйте как изображения)
print("\nПримеры сгенерированы.")
print("В реальном коде сохраняйте их как изображения для визуализации.")
return yandex_input, google_output
def main():
"""Основная функция."""
print("=" * 60)
print("Пример обучения GAN для преобразования Yandex → Google")
print("=" * 60)
# Настройка
trainer, config = setup_training()
# Обучение
train_model(trainer, config)
# Оценка (требует реальных тестовых данных)
# evaluate_model(trainer)
# Генерация примеров
# generate_examples(trainer.model, trainer.device)
print("\n" + "=" * 60)
print("Скрипт завершен.")
print("=" * 60)
print("\nСледующие шаги:")
print("1. Замените фиктивные даталоадеры на реальные данные")
print("2. Настройте параметры в create_simple_config()")
print("3. Запустите обучение с реальными данными")
print("4. Визуализируйте результаты")
print("=" * 60)
if __name__ == "__main__":
main()