ref: simplify and modularize GAN codebase
This commit is contained in:
30
models/GAN/main.py
Normal file
30
models/GAN/main.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Main entry point for GAN training."""
|
||||
|
||||
from config import create_config
|
||||
from dataloader import create_data_loaders
|
||||
from model import create_gan
|
||||
from trainer import create_trainer
|
||||
|
||||
|
||||
def main():
|
||||
"""Run training pipeline."""
|
||||
config = create_config()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Create components
|
||||
model = create_gan(use_cuda=False) # Set to True to use GPU
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=config["data_dir"],
|
||||
batch_size=config["batch_size"],
|
||||
image_size=tuple(config["image_size"]),
|
||||
num_workers=config["num_workers"],
|
||||
)
|
||||
trainer = create_trainer(model, train_loader, val_loader, config)
|
||||
|
||||
# Train
|
||||
trainer.train(config["epochs"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import torch
|
||||
main()
|
||||
Reference in New Issue
Block a user