improve schema
This commit is contained in:
@@ -42,6 +42,7 @@
|
||||
" \"dropout_rate\": 0.3,\n",
|
||||
" \"backbone\": \"resnet18\",\n",
|
||||
" \"output_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\models\\SiaN\\runs\",\n",
|
||||
" \"save_every_n_epochs\": 15,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@@ -88,7 +89,7 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "# SiaN Model\n"
|
||||
"source": "# Configuration\n\nGlobal settings for:\n- Data paths and image parameters\n- Training hyperparameters\n- Model architecture options\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -185,7 +186,7 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "Dataset for Google/Yandex image pairs with homography augmentation.\n"
|
||||
"source": "## Dataset\n\nGoogle/Yandex image pair loader with homography augmentation.\n\n**Features:**\n- Loads paired images from dual camera sources\n- Applies random homography transformations\n- Supports configurable train/val split\n\n**Returns:**\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -237,7 +238,7 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "HomographyCNN6 predicts 6 params: rx, ry, rz, tx, ty, scale.\n"
|
||||
"source": "## Model\n\n`HomographyCNN6` — CNN architecture for homography estimation.\n\n**Output:** 6 parameters\n- `rx, ry, rz` — rotation angles (radians)\n- `tx, ty` — translation offsets\n- `scale` — isotropic scale factor\n\n**Architecture:**\n- Dual-branch CNN (Google + Yandex images)\n- Shared backbone (configurable: resnet18/34/50)\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -310,6 +311,10 @@
|
||||
" self.save_checkpoint(epoch, is_best=True)\n",
|
||||
" print(f\"Best model saved (val loss: {val_metrics['loss']:.4f})\")\n",
|
||||
"\n",
|
||||
" if epoch % config[\"save_every_n_epochs\"] == 0:\n",
|
||||
" self.save_checkpoint(epoch, is_best=False)\n",
|
||||
" print(f\"Checkpoint saved at epoch {epoch}\")\n",
|
||||
"\n",
|
||||
" self.writer.close()\n",
|
||||
"\n",
|
||||
" def save_checkpoint(self, epoch, is_best=False):\n",
|
||||
@@ -325,7 +330,7 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "HomographyTrainer manages training loop with validation.\n"
|
||||
"source": "## Training\n\n`HomographyTrainer` — training loop with validation and checkpointing.\n\n**Features:**\n- Epoch-based training with tqdm progress bar\n- Adam optimizer with configurable LR\n- Validation after each epoch\n- Best model auto-save\n- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n\n**Checkpoint saving:**\n- `best_model.pt` — lowest validation loss\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -388,7 +393,7 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "Visualization and analysis of predictions.\n"
|
||||
"source": "## Analysis\n\nVisualization and evaluation tools:\n\n- Training metrics plots (loss curves)\n- Prediction visualization on sample images\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -403,75 +408,52 @@
|
||||
"logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(message)s\")\n",
|
||||
"logger = logging.getLogger(__name__)\n",
|
||||
"\n",
|
||||
"logger.info(\"=\" * 50)\n",
|
||||
"logger.info(\"SiaN Training Pipeline\")\n",
|
||||
"logger.info(\"=\" * 50)\n",
|
||||
"\n",
|
||||
"def create_dataset():\n",
|
||||
" logger.info(\"Creating data loaders...\")\n",
|
||||
" train_loader, val_loader = create_data_loaders(\n",
|
||||
" root_dir=config[\"data_dir\"],\n",
|
||||
" batch_size=config[\"batch_size\"],\n",
|
||||
" train_split=config[\"train_split\"],\n",
|
||||
" num_workers=config[\"num_workers\"],\n",
|
||||
" image_size=config[\"image_size\"],\n",
|
||||
" )\n",
|
||||
" logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
|
||||
" return train_loader, val_loader\n",
|
||||
"dataset_info = get_dataset_info()\n",
|
||||
"logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n",
|
||||
"\n",
|
||||
"train_loader, val_loader = create_data_loaders(\n",
|
||||
" root_dir=config[\"data_dir\"],\n",
|
||||
" batch_size=config[\"batch_size\"],\n",
|
||||
" train_split=config[\"train_split\"],\n",
|
||||
" num_workers=config[\"num_workers\"],\n",
|
||||
" image_size=config[\"image_size\"],\n",
|
||||
")\n",
|
||||
"logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
|
||||
"\n",
|
||||
"def create_model():\n",
|
||||
" logger.info(\"Creating model...\")\n",
|
||||
" model = HomographyCNN6(\n",
|
||||
" input_channels=3,\n",
|
||||
" backbone_name=config[\"backbone\"],\n",
|
||||
" pretrained=True,\n",
|
||||
" dropout_rate=config[\"dropout_rate\"]\n",
|
||||
" )\n",
|
||||
" logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n",
|
||||
" return model\n",
|
||||
"model = HomographyCNN6(\n",
|
||||
" input_channels=3,\n",
|
||||
" backbone_name=config[\"backbone\"],\n",
|
||||
" pretrained=True,\n",
|
||||
" dropout_rate=config[\"dropout_rate\"]\n",
|
||||
")\n",
|
||||
"logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n",
|
||||
"\n",
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"logger.info(f\"Using device: {device}\")\n",
|
||||
"\n",
|
||||
"def train_model(model, train_loader, val_loader):\n",
|
||||
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
" logger.info(f\"Using device: {device}\")\n",
|
||||
" \n",
|
||||
" trainer = HomographyTrainer(model, train_loader, val_loader, device)\n",
|
||||
" logger.info(\"Starting training...\")\n",
|
||||
" trainer.train(config[\"epochs\"])\n",
|
||||
" logger.info(\"Training completed\")\n",
|
||||
" return trainer\n",
|
||||
"trainer = HomographyTrainer(model, train_loader, val_loader, device)\n",
|
||||
"logger.info(\"Starting training...\")\n",
|
||||
"trainer.train(config[\"epochs\"])\n",
|
||||
"logger.info(\"Training completed\")\n",
|
||||
"\n",
|
||||
"logger.info(\"Analyzing model...\")\n",
|
||||
"results = analyze_training(trainer)\n",
|
||||
"logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n",
|
||||
"\n",
|
||||
"def analyze_model(trainer):\n",
|
||||
" logger.info(\"Analyzing model...\")\n",
|
||||
" results = analyze_training(trainer)\n",
|
||||
" logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n",
|
||||
" return results\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def main():\n",
|
||||
" logger.info(\"=\" * 50)\n",
|
||||
" logger.info(\"SiaN Training Pipeline\")\n",
|
||||
" logger.info(\"=\" * 50)\n",
|
||||
" \n",
|
||||
" dataset_info = get_dataset_info()\n",
|
||||
" logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n",
|
||||
" \n",
|
||||
" train_loader, val_loader = create_dataset()\n",
|
||||
" model = create_model()\n",
|
||||
" trainer = train_model(model, train_loader, val_loader)\n",
|
||||
" results = analyze_model(trainer)\n",
|
||||
" \n",
|
||||
" logger.info(\"=\" * 50)\n",
|
||||
" logger.info(\"Pipeline completed successfully\")\n",
|
||||
" logger.info(\"=\" * 50)\n",
|
||||
" \n",
|
||||
" return trainer, results\n",
|
||||
"logger.info(\"=\" * 50)\n",
|
||||
"logger.info(\"Pipeline completed successfully\")\n",
|
||||
"logger.info(\"=\" * 50)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "Run main() to execute the full pipeline.\n"
|
||||
"source": "## Main Pipeline\n\nExecutes the full training workflow:\n1. Load dataset info\n2. Create data loaders\n3. Initialize model\n4. Train with validation\n5. Analyze and export results\n\n**Outputs:**\n- Model checkpoints in `runs/checkpoints/`\n- TensorBoard logs in `runs/`\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -479,8 +461,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if __name__ == \"__main__\":\n",
|
||||
" main()\n",
|
||||
"!zip artefacts.zip runs/gan_training/checkpoints/best_model.pt\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user