30 lines
824 B
Python
30 lines
824 B
Python
"""Main entry point for GAN training."""
|
|
|
|
from config import create_config
|
|
from dataloader import create_data_loaders
|
|
from model import create_gan
|
|
from trainer import create_trainer
|
|
|
|
|
|
def main():
    """Run the GAN training pipeline end to end.

    Loads the configuration, builds the model, the train/validation data
    loaders and the trainer, then trains for ``config["epochs"]`` epochs.
    """
    config = create_config()

    # NOTE: no torch.device is selected here. The earlier
    # ``device = torch.device(...)`` local was never used, and it made
    # main() fail with NameError unless ``torch`` had been imported by the
    # ``__main__`` guard (i.e. whenever main() was called from another module).

    # Create components
    model = create_gan(use_cuda=False)  # Set to True to use GPU
    train_loader, val_loader = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=config["batch_size"],
        # The loader is given a tuple regardless of how the config stores it.
        image_size=tuple(config["image_size"]),
        num_workers=config["num_workers"],
    )
    trainer = create_trainer(model, train_loader, val_loader, config)

    # Train
    trainer.train(config["epochs"])
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: torch is imported only when this file is executed
    # directly, and it is bound before the training pipeline is started.
    import torch

    main()