feat: complete sian-similarity
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
Training script for image similarity estimation.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
@@ -10,7 +9,7 @@ from datetime import datetime
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from dataloader import create_data_loaders
|
||||
from dataloader import config, create_data_loaders
|
||||
from model import SimilarityCNN, SimilarityLoss, create_similarity_model
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
@@ -191,51 +190,23 @@ class SimilarityTrainer:
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Train similarity estimation model")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
type=str,
|
||||
default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
parser.add_argument("--epochs", type=int, default=100)
|
||||
parser.add_argument("--learning_rate", type=float, default=2e-4)
|
||||
parser.add_argument("--image_size", type=int, default=256)
|
||||
parser.add_argument("--train_split", type=float, default=0.8)
|
||||
parser.add_argument("--output_dir", type=str, default="runs/similarity")
|
||||
parser.add_argument("--num_workers", type=int, default=0)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
# Use config from dataloader.py
|
||||
config_dict = config.copy()
|
||||
|
||||
args = parser.parse_args()
|
||||
# Ensure image_size is tuple
|
||||
if isinstance(config_dict.get("image_size"), list):
|
||||
config_dict["image_size"] = tuple(config_dict["image_size"])
|
||||
|
||||
config = {
|
||||
"data_dir": args.data_dir,
|
||||
"batch_size": args.batch_size,
|
||||
"epochs": args.epochs,
|
||||
"learning_rate": args.learning_rate,
|
||||
"image_size": (args.image_size, args.image_size),
|
||||
"train_split": args.train_split,
|
||||
"output_dir": args.output_dir,
|
||||
"num_workers": args.num_workers,
|
||||
"log_interval": 10,
|
||||
"save_interval": 5,
|
||||
"early_stopping_patience": 20,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
}
|
||||
|
||||
device = torch.device(args.device)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
print("Creating data loaders...")
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=config["data_dir"],
|
||||
batch_size=config["batch_size"],
|
||||
train_split=config["train_split"],
|
||||
num_workers=config["num_workers"],
|
||||
image_size=config["image_size"],
|
||||
root_dir=config_dict["data_dir"],
|
||||
batch_size=config_dict["batch_size"],
|
||||
train_split=config_dict["train_split"],
|
||||
num_workers=config_dict["num_workers"],
|
||||
image_size=config_dict["image_size"],
|
||||
augment_train=True,
|
||||
augment_val=False,
|
||||
device=device,
|
||||
@@ -247,7 +218,9 @@ def main():
|
||||
print("Creating model...")
|
||||
model = create_similarity_model(
|
||||
model_type="cnn",
|
||||
input_size=config["image_size"],
|
||||
input_size=config_dict["image_size"][0]
|
||||
if isinstance(config_dict["image_size"], (tuple, list))
|
||||
else config_dict["image_size"],
|
||||
input_channels=3,
|
||||
hidden_channels=64,
|
||||
num_blocks=4,
|
||||
@@ -262,11 +235,11 @@ def main():
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
config=config_dict,
|
||||
)
|
||||
|
||||
print("Starting training...")
|
||||
trainer.train(config["epochs"])
|
||||
trainer.train(config_dict["epochs"])
|
||||
|
||||
print("Training completed!")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user