|
|
|
|
@@ -1,4 +1,5 @@
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
@@ -12,69 +13,42 @@ from .utils import config
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
logger.info("=" * 50)
|
|
|
|
|
logger.info("SiaN Training Pipeline")
|
|
|
|
|
logger.info("=" * 50)
|
|
|
|
|
|
|
|
|
|
def create_dataset():
|
|
|
|
|
logger.info("Creating data loaders...")
|
|
|
|
|
train_loader, val_loader = create_data_loaders(
|
|
|
|
|
root_dir=config["data_dir"],
|
|
|
|
|
batch_size=config["batch_size"],
|
|
|
|
|
train_split=config["train_split"],
|
|
|
|
|
num_workers=config["num_workers"],
|
|
|
|
|
image_size=config["image_size"],
|
|
|
|
|
)
|
|
|
|
|
logger.info(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
|
|
|
|
|
return train_loader, val_loader
|
|
|
|
|
dataset_info = get_dataset_info()
|
|
|
|
|
logger.info(f"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}")
|
|
|
|
|
|
|
|
|
|
train_loader, val_loader = create_data_loaders(
|
|
|
|
|
root_dir=config["data_dir"],
|
|
|
|
|
batch_size=config["batch_size"],
|
|
|
|
|
train_split=config["train_split"],
|
|
|
|
|
num_workers=config["num_workers"],
|
|
|
|
|
image_size=config["image_size"],
|
|
|
|
|
)
|
|
|
|
|
logger.info(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
|
|
|
|
|
|
|
|
|
|
def create_model():
|
|
|
|
|
logger.info("Creating model...")
|
|
|
|
|
model = HomographyCNN6(
|
|
|
|
|
input_channels=3,
|
|
|
|
|
backbone_name=config["backbone"],
|
|
|
|
|
pretrained=True,
|
|
|
|
|
dropout_rate=config["dropout_rate"]
|
|
|
|
|
)
|
|
|
|
|
logger.info(f"Model created with {count_parameters(model):,} parameters")
|
|
|
|
|
return model
|
|
|
|
|
model = HomographyCNN6(
|
|
|
|
|
input_channels=3,
|
|
|
|
|
backbone_name=config["backbone"],
|
|
|
|
|
pretrained=True,
|
|
|
|
|
dropout_rate=config["dropout_rate"]
|
|
|
|
|
)
|
|
|
|
|
logger.info(f"Model created with {count_parameters(model):,} parameters")
|
|
|
|
|
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
logger.info(f"Using device: {device}")
|
|
|
|
|
|
|
|
|
|
def train_model(model, train_loader, val_loader):
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
logger.info(f"Using device: {device}")
|
|
|
|
|
|
|
|
|
|
trainer = HomographyTrainer(model, train_loader, val_loader, device)
|
|
|
|
|
logger.info("Starting training...")
|
|
|
|
|
trainer.train(config["epochs"])
|
|
|
|
|
logger.info("Training completed")
|
|
|
|
|
return trainer
|
|
|
|
|
trainer = HomographyTrainer(model, train_loader, val_loader, device)
|
|
|
|
|
logger.info("Starting training...")
|
|
|
|
|
trainer.train(config["epochs"])
|
|
|
|
|
logger.info("Training completed")
|
|
|
|
|
|
|
|
|
|
logger.info("Analyzing model...")
|
|
|
|
|
results = analyze_training(trainer)
|
|
|
|
|
logger.info(f"Analysis complete: best_val_loss={results['best_val_loss']:.4f}")
|
|
|
|
|
|
|
|
|
|
def analyze_model(trainer):
|
|
|
|
|
logger.info("Analyzing model...")
|
|
|
|
|
results = analyze_training(trainer)
|
|
|
|
|
logger.info(f"Analysis complete: best_val_loss={results['best_val_loss']:.4f}")
|
|
|
|
|
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
logger.info("=" * 50)
|
|
|
|
|
logger.info("SiaN Training Pipeline")
|
|
|
|
|
logger.info("=" * 50)
|
|
|
|
|
|
|
|
|
|
dataset_info = get_dataset_info()
|
|
|
|
|
logger.info(f"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}")
|
|
|
|
|
|
|
|
|
|
train_loader, val_loader = create_dataset()
|
|
|
|
|
model = create_model()
|
|
|
|
|
trainer = train_model(model, train_loader, val_loader)
|
|
|
|
|
results = analyze_model(trainer)
|
|
|
|
|
|
|
|
|
|
logger.info("=" * 50)
|
|
|
|
|
logger.info("Pipeline completed successfully")
|
|
|
|
|
logger.info("=" * 50)
|
|
|
|
|
|
|
|
|
|
return trainer, results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
main()
|
|
|
|
|
logger.info("=" * 50)
|
|
|
|
|
logger.info("Pipeline completed successfully")
|
|
|
|
|
logger.info("=" * 50)
|
|
|
|
|
|