feat: add models

This commit is contained in:
2026-02-20 16:52:02 +03:00
parent 6040f3b253
commit 0cc210968f
11 changed files with 4488 additions and 48 deletions

View File

@@ -0,0 +1,136 @@
"""
Минимальный пример использования GAN trainer для преобразования Yandex → Google карт.
Этот пример показывает самый простой способ использования тренера.
"""
import sys
from pathlib import Path
import torch
from torch.utils.data import DataLoader, Dataset
# Добавляем путь к модулям
sys.path.append(str(Path(__file__).parent.parent.parent))
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
class SimpleMapDataset(Dataset):
"""Простой датасет с фиктивными данными для примера."""
def __init__(self, num_samples=100, image_size=(256, 256)):
self.num_samples = num_samples
self.image_size = image_size
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# Создаем фиктивные изображения
# В реальном коде замените на загрузку реальных изображений
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
return {"yandex_img": yandex_img, "google_img": google_img}
def main():
"""Основная функция минимального примера."""
print("Минимальный пример использования GAN trainer")
print("=" * 50)
# 1. Конфигурация (минимальный набор параметров)
config = {
"learning_rate": 2e-4,
"batch_size": 4,
"output_dir": "runs/gan_minimal",
}
# 2. Устройство (CPU или GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Используемое устройство: {device}")
# 3. Создание модели
print("\nСоздание GAN модели...")
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode="vanilla", # Простейший режим
lambda_L1=100.0, # Стандартный вес L1 потерь
use_cuda=(device.type == "cuda"),
)
# 4. Создание даталоадеров
print("Создание даталоадеров...")
train_dataset = SimpleMapDataset(num_samples=50)
val_dataset = SimpleMapDataset(num_samples=10)
train_loader = DataLoader(
train_dataset,
batch_size=config["batch_size"],
shuffle=True,
num_workers=0,
)
val_loader = DataLoader(
val_dataset,
batch_size=config["batch_size"],
shuffle=False,
num_workers=0,
)
print(f" Обучающих примеров: {len(train_dataset)}")
print(f" Валидационных примеров: {len(val_dataset)}")
# 5. Создание тренера
print("\nСоздание тренера...")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# 6. Обучение на небольшом количестве эпох
print("\nЗапуск обучения (3 эпохи для примера)...")
print("=" * 50)
trainer.train(num_epochs=3)
# 7. Генерация примеров
print("\nГенерация примеров преобразования...")
model.eval()
# Создаем тестовые данные
test_yandex = torch.randn(2, 3, 256, 256).to(device)
with torch.no_grad():
generated_google = model(test_yandex)
print(f"Входные изображения: {test_yandex.shape}")
print(f"Сгенерированные изображения: {generated_google.shape}")
print(
f"Диапазон значений: [{generated_google.min():.3f}, {generated_google.max():.3f}]"
)
# 8. Сохранение финальной модели
print("\nСохранение модели...")
model_save_path = "gan_model_minimal.pth"
torch.save(model.state_dict(), model_save_path)
print(f"Модель сохранена в: {model_save_path}")
print("\n" + "=" * 50)
print("Минимальный пример завершен!")
print("\nДля реального использования:")
print("1. Замените SimpleMapDataset на ваш реальный датасет")
print("2. Настройте параметры в config")
print("3. Увеличьте количество эпох (например, до 100)")
print("4. Используйте реальные изображения карт")
print("=" * 50)
if __name__ == "__main__":
main()