feat: complete sian-similarity

This commit is contained in:
2026-03-22 14:29:00 +03:00
parent 43cd4222bc
commit 05f8746d58
8 changed files with 3780 additions and 1903 deletions

View File

@@ -2,7 +2,6 @@
Training script for image similarity estimation.
"""
import argparse
import os
import time
from datetime import datetime
@@ -10,7 +9,7 @@ from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import create_data_loaders
from dataloader import config, create_data_loaders
from model import SimilarityCNN, SimilarityLoss, create_similarity_model
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
@@ -191,51 +190,23 @@ class SimilarityTrainer:
def main():
parser = argparse.ArgumentParser(description="Train similarity estimation model")
parser.add_argument(
"--data_dir",
type=str,
default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--image_size", type=int, default=256)
parser.add_argument("--train_split", type=float, default=0.8)
parser.add_argument("--output_dir", type=str, default="runs/similarity")
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
# Use config from dataloader.py
config_dict = config.copy()
args = parser.parse_args()
# Ensure image_size is tuple
if isinstance(config_dict.get("image_size"), list):
config_dict["image_size"] = tuple(config_dict["image_size"])
config = {
"data_dir": args.data_dir,
"batch_size": args.batch_size,
"epochs": args.epochs,
"learning_rate": args.learning_rate,
"image_size": (args.image_size, args.image_size),
"train_split": args.train_split,
"output_dir": args.output_dir,
"num_workers": args.num_workers,
"log_interval": 10,
"save_interval": 5,
"early_stopping_patience": 20,
"beta1": 0.5,
"beta2": 0.999,
}
device = torch.device(args.device)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print("Creating data loaders...")
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=config["batch_size"],
train_split=config["train_split"],
num_workers=config["num_workers"],
image_size=config["image_size"],
root_dir=config_dict["data_dir"],
batch_size=config_dict["batch_size"],
train_split=config_dict["train_split"],
num_workers=config_dict["num_workers"],
image_size=config_dict["image_size"],
augment_train=True,
augment_val=False,
device=device,
@@ -247,7 +218,9 @@ def main():
print("Creating model...")
model = create_similarity_model(
model_type="cnn",
input_size=config["image_size"],
input_size=config_dict["image_size"][0]
if isinstance(config_dict["image_size"], (tuple, list))
else config_dict["image_size"],
input_channels=3,
hidden_channels=64,
num_blocks=4,
@@ -262,11 +235,11 @@ def main():
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
config=config_dict,
)
print("Starting training...")
trainer.train(config["epochs"])
trainer.train(config_dict["epochs"])
print("Training completed!")