36 lines
921 B
Python
36 lines
921 B
Python
"""Configuration for GAN training."""
|
|
|
|
|
|
def create_config(**overrides):
    """Create the default GAN training configuration dictionary.

    A brand-new dictionary (including a new ``image_size`` list) is built on
    every call, so callers may mutate the result freely without affecting
    later calls.

    Args:
        **overrides: Optional replacement values for existing keys, for
            example ``create_config(batch_size=64, epochs=10)``. With no
            arguments the defaults are returned unchanged.

    Returns:
        dict: Mapping of setting name to value.

    Raises:
        TypeError: If an override names a key that is not one of the
            defaults, so a typo such as ``batchsize=64`` fails loudly
            instead of being silently ignored.
    """
    config = {
        # Optimizer params
        # NOTE(review): beta1/beta2 presumably feed an Adam-style optimizer;
        # confirm against the training code.
        "learning_rate": 2e-4,
        "beta1": 0.5,
        "beta2": 0.999,
        # Training params
        "batch_size": 32,
        "epochs": 100,
        # GAN params
        # NOTE(review): lambda_L1 presumably weights an L1 reconstruction
        # term relative to the adversarial loss; confirm in the loss code.
        "gan_mode": "vanilla",
        "lambda_L1": 100.0,
        # Regularization
        "grad_clip": 1.0,
        # Early stopping
        "early_stopping_patience": 20,
        # Output
        "output_dir": "runs/gan_training",
        # Logging
        "log_interval": 10,
        "save_interval": 5,
        # Data
        # NOTE: this default is a machine-specific absolute Windows path;
        # pass data_dir=... to point at the dataset on other machines.
        "data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        "image_size": [256, 256],
        "train_split": 0.8,
        "num_workers": 0,
    }

    # Reject keys that do not exist in the defaults so misspelled settings
    # surface immediately rather than being silently added and never read.
    unknown = sorted(set(overrides) - set(config))
    if unknown:
        raise TypeError(f"Unknown config key(s): {', '.join(unknown)}")

    config.update(overrides)
    return config
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this module directly dumps the default settings for inspection.
    print("Default config:", create_config())