feat: add models
This commit is contained in:
136
models/GAN/minimal_example.py
Normal file
136
models/GAN/minimal_example.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Минимальный пример использования GAN trainer для преобразования Yandex → Google карт.
|
||||
|
||||
Этот пример показывает самый простой способ использования тренера.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
# Добавляем путь к модулям
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from models.GAN.gan import create_image_gan
|
||||
from models.GAN.trainer import GANTrainer
|
||||
|
||||
|
||||
class SimpleMapDataset(Dataset):
|
||||
"""Простой датасет с фиктивными данными для примера."""
|
||||
|
||||
def __init__(self, num_samples=100, image_size=(256, 256)):
|
||||
self.num_samples = num_samples
|
||||
self.image_size = image_size
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# Создаем фиктивные изображения
|
||||
# В реальном коде замените на загрузку реальных изображений
|
||||
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||
|
||||
return {"yandex_img": yandex_img, "google_img": google_img}
|
||||
|
||||
|
||||
def main():
|
||||
"""Основная функция минимального примера."""
|
||||
print("Минимальный пример использования GAN trainer")
|
||||
print("=" * 50)
|
||||
|
||||
# 1. Конфигурация (минимальный набор параметров)
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"batch_size": 4,
|
||||
"output_dir": "runs/gan_minimal",
|
||||
}
|
||||
|
||||
# 2. Устройство (CPU или GPU)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Используемое устройство: {device}")
|
||||
|
||||
# 3. Создание модели
|
||||
print("\nСоздание GAN модели...")
|
||||
model = create_image_gan(
|
||||
input_channels=3,
|
||||
output_channels=3,
|
||||
gan_mode="vanilla", # Простейший режим
|
||||
lambda_L1=100.0, # Стандартный вес L1 потерь
|
||||
use_cuda=(device.type == "cuda"),
|
||||
)
|
||||
|
||||
# 4. Создание даталоадеров
|
||||
print("Создание даталоадеров...")
|
||||
train_dataset = SimpleMapDataset(num_samples=50)
|
||||
val_dataset = SimpleMapDataset(num_samples=10)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config["batch_size"],
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config["batch_size"],
|
||||
shuffle=False,
|
||||
num_workers=0,
|
||||
)
|
||||
|
||||
print(f" Обучающих примеров: {len(train_dataset)}")
|
||||
print(f" Валидационных примеров: {len(val_dataset)}")
|
||||
|
||||
# 5. Создание тренера
|
||||
print("\nСоздание тренера...")
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# 6. Обучение на небольшом количестве эпох
|
||||
print("\nЗапуск обучения (3 эпохи для примера)...")
|
||||
print("=" * 50)
|
||||
|
||||
trainer.train(num_epochs=3)
|
||||
|
||||
# 7. Генерация примеров
|
||||
print("\nГенерация примеров преобразования...")
|
||||
model.eval()
|
||||
|
||||
# Создаем тестовые данные
|
||||
test_yandex = torch.randn(2, 3, 256, 256).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_google = model(test_yandex)
|
||||
|
||||
print(f"Входные изображения: {test_yandex.shape}")
|
||||
print(f"Сгенерированные изображения: {generated_google.shape}")
|
||||
print(
|
||||
f"Диапазон значений: [{generated_google.min():.3f}, {generated_google.max():.3f}]"
|
||||
)
|
||||
|
||||
# 8. Сохранение финальной модели
|
||||
print("\nСохранение модели...")
|
||||
model_save_path = "gan_model_minimal.pth"
|
||||
torch.save(model.state_dict(), model_save_path)
|
||||
print(f"Модель сохранена в: {model_save_path}")
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Минимальный пример завершен!")
|
||||
print("\nДля реального использования:")
|
||||
print("1. Замените SimpleMapDataset на ваш реальный датасет")
|
||||
print("2. Настройте параметры в config")
|
||||
print("3. Увеличьте количество эпох (например, до 100)")
|
||||
print("4. Используйте реальные изображения карт")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user