Files
autopilot/models/GAN/main.py

30 lines
824 B
Python

"""Main entry point for GAN training."""
from config import create_config
from dataloader import create_data_loaders
from model import create_gan
from trainer import create_trainer
def main(*, use_cuda: bool = False) -> None:
    """Run the GAN training pipeline end to end.

    Loads the configuration, builds the model, the train/validation data
    loaders and the trainer, then calls ``trainer.train`` with
    ``config["epochs"]``.

    Args:
        use_cuda: Forwarded unchanged to ``create_gan``. Defaults to
            ``False``, the value this entry point previously hard-coded;
            pass ``True`` to request GPU use.
    """
    config = create_config()

    # No torch.device is built here: the value was never consumed, and
    # computing it depended on a ``torch`` name that this file only binds
    # under the ``__main__`` guard, so main() raised NameError when it was
    # called from an importing module.
    model = create_gan(use_cuda=use_cuda)

    # image_size is converted to a tuple before being handed to the loader
    # factory (the config presumably stores it as a list -- verify against
    # create_config).
    train_loader, val_loader = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=config["batch_size"],
        image_size=tuple(config["image_size"]),
        num_workers=config["num_workers"],
    )

    trainer = create_trainer(model, train_loader, val_loader, config)
    trainer.train(config["epochs"])


if __name__ == "__main__":
    # Script entry point. This file imports torch only on this path, i.e.
    # when it is executed directly rather than imported by other code; the
    # import must stay ahead of the main() call.
    import torch

    main()