Files
autopilot/models/GAN/minimal_example.py
2026-02-20 16:52:02 +03:00

137 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Минимальный пример использования GAN trainer для преобразования Yandex → Google карт.
Этот пример показывает самый простой способ использования тренера.
"""
import sys
from pathlib import Path
import torch
from torch.utils.data import DataLoader, Dataset
# Добавляем путь к модулям
sys.path.append(str(Path(__file__).parent.parent.parent))
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
class SimpleMapDataset(Dataset):
"""Простой датасет с фиктивными данными для примера."""
def __init__(self, num_samples=100, image_size=(256, 256)):
self.num_samples = num_samples
self.image_size = image_size
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# Создаем фиктивные изображения
# В реальном коде замените на загрузку реальных изображений
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
return {"yandex_img": yandex_img, "google_img": google_img}
def main():
"""Основная функция минимального примера."""
print("Минимальный пример использования GAN trainer")
print("=" * 50)
# 1. Конфигурация (минимальный набор параметров)
config = {
"learning_rate": 2e-4,
"batch_size": 4,
"output_dir": "runs/gan_minimal",
}
# 2. Устройство (CPU или GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Используемое устройство: {device}")
# 3. Создание модели
print("\nСоздание GAN модели...")
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode="vanilla", # Простейший режим
lambda_L1=100.0, # Стандартный вес L1 потерь
use_cuda=(device.type == "cuda"),
)
# 4. Создание даталоадеров
print("Создание даталоадеров...")
train_dataset = SimpleMapDataset(num_samples=50)
val_dataset = SimpleMapDataset(num_samples=10)
train_loader = DataLoader(
train_dataset,
batch_size=config["batch_size"],
shuffle=True,
num_workers=0,
)
val_loader = DataLoader(
val_dataset,
batch_size=config["batch_size"],
shuffle=False,
num_workers=0,
)
print(f" Обучающих примеров: {len(train_dataset)}")
print(f" Валидационных примеров: {len(val_dataset)}")
# 5. Создание тренера
print("\nСоздание тренера...")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# 6. Обучение на небольшом количестве эпох
print("\nЗапуск обучения (3 эпохи для примера)...")
print("=" * 50)
trainer.train(num_epochs=3)
# 7. Генерация примеров
print("\nГенерация примеров преобразования...")
model.eval()
# Создаем тестовые данные
test_yandex = torch.randn(2, 3, 256, 256).to(device)
with torch.no_grad():
generated_google = model(test_yandex)
print(f"Входные изображения: {test_yandex.shape}")
print(f"Сгенерированные изображения: {generated_google.shape}")
print(
f"Диапазон значений: [{generated_google.min():.3f}, {generated_google.max():.3f}]"
)
# 8. Сохранение финальной модели
print("\nСохранение модели...")
model_save_path = "gan_model_minimal.pth"
torch.save(model.state_dict(), model_save_path)
print(f"Модель сохранена в: {model_save_path}")
print("\n" + "=" * 50)
print("Минимальный пример завершен!")
print("\nДля реального использования:")
print("1. Замените SimpleMapDataset на ваш реальный датасет")
print("2. Настройте параметры в config")
print("3. Увеличьте количество эпох (например, до 100)")
print("4. Используйте реальные изображения карт")
print("=" * 50)
if __name__ == "__main__":
main()