improve schema

This commit is contained in:
2026-04-04 22:57:41 +03:00
parent b2cc714d79
commit ec8b3ae20e
9 changed files with 199 additions and 146 deletions

View File

@@ -78,7 +78,7 @@ def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0,
split = int(train_split * len(indices))
train_ds = Subset(aug_ds if augment_train else full_ds, indices[:split])
val_ds = Subset(full_ds, indices[split:])
val_ds = Subset(aug_ds, indices[split:])
return (DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True),
DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True))

View File

@@ -1,4 +1,5 @@
import logging
import os
import torch
@@ -12,69 +13,42 @@ from .utils import config
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
logger = logging.getLogger(__name__)
logger.info("=" * 50)
logger.info("SiaN Training Pipeline")
logger.info("=" * 50)
def create_dataset():
logger.info("Creating data loaders...")
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=config["batch_size"],
train_split=config["train_split"],
num_workers=config["num_workers"],
image_size=config["image_size"],
)
logger.info(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
return train_loader, val_loader
dataset_info = get_dataset_info()
logger.info(f"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}")
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=config["batch_size"],
train_split=config["train_split"],
num_workers=config["num_workers"],
image_size=config["image_size"],
)
logger.info(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
def create_model():
logger.info("Creating model...")
model = HomographyCNN6(
input_channels=3,
backbone_name=config["backbone"],
pretrained=True,
dropout_rate=config["dropout_rate"]
)
logger.info(f"Model created with {count_parameters(model):,} parameters")
return model
model = HomographyCNN6(
input_channels=3,
backbone_name=config["backbone"],
pretrained=True,
dropout_rate=config["dropout_rate"]
)
logger.info(f"Model created with {count_parameters(model):,} parameters")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
def train_model(model, train_loader, val_loader):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
trainer = HomographyTrainer(model, train_loader, val_loader, device)
logger.info("Starting training...")
trainer.train(config["epochs"])
logger.info("Training completed")
return trainer
trainer = HomographyTrainer(model, train_loader, val_loader, device)
logger.info("Starting training...")
trainer.train(config["epochs"])
logger.info("Training completed")
logger.info("Analyzing model...")
results = analyze_training(trainer)
logger.info(f"Analysis complete: best_val_loss={results['best_val_loss']:.4f}")
def analyze_model(trainer):
logger.info("Analyzing model...")
results = analyze_training(trainer)
logger.info(f"Analysis complete: best_val_loss={results['best_val_loss']:.4f}")
return results
def main():
logger.info("=" * 50)
logger.info("SiaN Training Pipeline")
logger.info("=" * 50)
dataset_info = get_dataset_info()
logger.info(f"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}")
train_loader, val_loader = create_dataset()
model = create_model()
trainer = train_model(model, train_loader, val_loader)
results = analyze_model(trainer)
logger.info("=" * 50)
logger.info("Pipeline completed successfully")
logger.info("=" * 50)
return trainer, results
if __name__ == "__main__":
main()
logger.info("=" * 50)
logger.info("Pipeline completed successfully")
logger.info("=" * 50)

View File

@@ -72,6 +72,10 @@ class HomographyTrainer:
self.save_checkpoint(epoch, is_best=True)
print(f"Best model saved (val loss: {val_metrics['loss']:.4f})")
if epoch % config["save_every_n_epochs"] == 0:
self.save_checkpoint(epoch, is_best=False)
print(f"Checkpoint saved at epoch {epoch}")
self.writer.close()
def save_checkpoint(self, epoch, is_best=False):

View File

@@ -12,6 +12,7 @@ config = {
"dropout_rate": 0.3,
"backbone": "resnet18",
"output_dir": r"C:\Users\admin\Projects\autopilot\models\SiaN\runs",
"save_every_n_epochs": 15,
}