137 lines
4.9 KiB
Python
137 lines
4.9 KiB
Python
"""
|
||
Минимальный пример использования GAN trainer для преобразования Yandex → Google карт.
|
||
|
||
Этот пример показывает самый простой способ использования тренера.
|
||
"""
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
from torch.utils.data import DataLoader, Dataset
|
||
|
||
# Добавляем путь к модулям
|
||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||
|
||
from models.GAN.gan import create_image_gan
|
||
from models.GAN.trainer import GANTrainer
|
||
|
||
|
||
class SimpleMapDataset(Dataset):
|
||
"""Простой датасет с фиктивными данными для примера."""
|
||
|
||
def __init__(self, num_samples=100, image_size=(256, 256)):
|
||
self.num_samples = num_samples
|
||
self.image_size = image_size
|
||
|
||
def __len__(self):
|
||
return self.num_samples
|
||
|
||
def __getitem__(self, idx):
|
||
# Создаем фиктивные изображения
|
||
# В реальном коде замените на загрузку реальных изображений
|
||
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||
|
||
return {"yandex_img": yandex_img, "google_img": google_img}
|
||
|
||
|
||
def main():
|
||
"""Основная функция минимального примера."""
|
||
print("Минимальный пример использования GAN trainer")
|
||
print("=" * 50)
|
||
|
||
# 1. Конфигурация (минимальный набор параметров)
|
||
config = {
|
||
"learning_rate": 2e-4,
|
||
"batch_size": 4,
|
||
"output_dir": "runs/gan_minimal",
|
||
}
|
||
|
||
# 2. Устройство (CPU или GPU)
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
print(f"Используемое устройство: {device}")
|
||
|
||
# 3. Создание модели
|
||
print("\nСоздание GAN модели...")
|
||
model = create_image_gan(
|
||
input_channels=3,
|
||
output_channels=3,
|
||
gan_mode="vanilla", # Простейший режим
|
||
lambda_L1=100.0, # Стандартный вес L1 потерь
|
||
use_cuda=(device.type == "cuda"),
|
||
)
|
||
|
||
# 4. Создание даталоадеров
|
||
print("Создание даталоадеров...")
|
||
train_dataset = SimpleMapDataset(num_samples=50)
|
||
val_dataset = SimpleMapDataset(num_samples=10)
|
||
|
||
train_loader = DataLoader(
|
||
train_dataset,
|
||
batch_size=config["batch_size"],
|
||
shuffle=True,
|
||
num_workers=0,
|
||
)
|
||
|
||
val_loader = DataLoader(
|
||
val_dataset,
|
||
batch_size=config["batch_size"],
|
||
shuffle=False,
|
||
num_workers=0,
|
||
)
|
||
|
||
print(f" Обучающих примеров: {len(train_dataset)}")
|
||
print(f" Валидационных примеров: {len(val_dataset)}")
|
||
|
||
# 5. Создание тренера
|
||
print("\nСоздание тренера...")
|
||
trainer = GANTrainer(
|
||
model=model,
|
||
train_loader=train_loader,
|
||
val_loader=val_loader,
|
||
device=device,
|
||
config=config,
|
||
)
|
||
|
||
# 6. Обучение на небольшом количестве эпох
|
||
print("\nЗапуск обучения (3 эпохи для примера)...")
|
||
print("=" * 50)
|
||
|
||
trainer.train(num_epochs=3)
|
||
|
||
# 7. Генерация примеров
|
||
print("\nГенерация примеров преобразования...")
|
||
model.eval()
|
||
|
||
# Создаем тестовые данные
|
||
test_yandex = torch.randn(2, 3, 256, 256).to(device)
|
||
|
||
with torch.no_grad():
|
||
generated_google = model(test_yandex)
|
||
|
||
print(f"Входные изображения: {test_yandex.shape}")
|
||
print(f"Сгенерированные изображения: {generated_google.shape}")
|
||
print(
|
||
f"Диапазон значений: [{generated_google.min():.3f}, {generated_google.max():.3f}]"
|
||
)
|
||
|
||
# 8. Сохранение финальной модели
|
||
print("\nСохранение модели...")
|
||
model_save_path = "gan_model_minimal.pth"
|
||
torch.save(model.state_dict(), model_save_path)
|
||
print(f"Модель сохранена в: {model_save_path}")
|
||
|
||
print("\n" + "=" * 50)
|
||
print("Минимальный пример завершен!")
|
||
print("\nДля реального использования:")
|
||
print("1. Замените SimpleMapDataset на ваш реальный датасет")
|
||
print("2. Настройте параметры в config")
|
||
print("3. Увеличьте количество эпох (например, до 100)")
|
||
print("4. Используйте реальные изображения карт")
|
||
print("=" * 50)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|