improve schema

This commit is contained in:
2026-04-04 22:57:41 +03:00
parent b2cc714d79
commit ec8b3ae20e
9 changed files with 199 additions and 146 deletions

View File

@@ -42,6 +42,7 @@
" \"dropout_rate\": 0.3,\n",
" \"backbone\": \"resnet18\",\n",
" \"output_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\models\\SiaN\\runs\",\n",
" \"save_every_n_epochs\": 15,\n",
"}\n",
"\n",
"\n",
@@ -88,7 +89,7 @@
{
"cell_type": "markdown",
"metadata": {},
"source": "# SiaN Model\n"
"source": "# Configuration\n\nGlobal settings for:\n- Data paths and image parameters\n- Training hyperparameters\n- Model architecture options\n"
},
{
"cell_type": "code",
@@ -185,7 +186,7 @@
{
"cell_type": "markdown",
"metadata": {},
"source": "Dataset for Google/Yandex image pairs with homography augmentation.\n"
"source": "## Dataset\n\nGoogle/Yandex image pair loader with homography augmentation.\n\n**Features:**\n- Loads paired images from dual camera sources\n- Applies random homography transformations\n- Supports configurable train/val split\n\n**Returns:**\n"
},
{
"cell_type": "code",
@@ -237,7 +238,7 @@
{
"cell_type": "markdown",
"metadata": {},
"source": "HomographyCNN6 predicts 6 params: rx, ry, rz, tx, ty, scale.\n"
"source": "## Model\n\n`HomographyCNN6` — CNN architecture for homography estimation.\n\n**Output:** 6 parameters\n- `rx, ry, rz` — rotation angles (radians)\n- `tx, ty` — translation offsets\n- `scale` — isotropic scale factor\n\n**Architecture:**\n- Dual-branch CNN (Google + Yandex images)\n- Shared backbone (configurable: resnet18/34/50)\n"
},
{
"cell_type": "code",
@@ -310,6 +311,10 @@
" self.save_checkpoint(epoch, is_best=True)\n",
" print(f\"Best model saved (val loss: {val_metrics['loss']:.4f})\")\n",
"\n",
" if epoch % config[\"save_every_n_epochs\"] == 0:\n",
" self.save_checkpoint(epoch, is_best=False)\n",
" print(f\"Checkpoint saved at epoch {epoch}\")\n",
"\n",
" self.writer.close()\n",
"\n",
" def save_checkpoint(self, epoch, is_best=False):\n",
@@ -325,7 +330,7 @@
{
"cell_type": "markdown",
"metadata": {},
"source": "HomographyTrainer manages training loop with validation.\n"
"source": "## Training\n\n`HomographyTrainer` — training loop with validation and checkpointing.\n\n**Features:**\n- Epoch-based training with tqdm progress bar\n- Adam optimizer with configurable LR\n- Validation after each epoch\n- Best model auto-save\n- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n\n**Checkpoint saving:**\n- `best_model.pt` — lowest validation loss\n"
},
{
"cell_type": "code",
@@ -388,7 +393,7 @@
{
"cell_type": "markdown",
"metadata": {},
"source": "Visualization and analysis of predictions.\n"
"source": "## Analysis\n\nVisualization and evaluation tools:\n\n- Training metrics plots (loss curves)\n- Prediction visualization on sample images\n"
},
{
"cell_type": "code",
@@ -403,75 +408,52 @@
"logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(message)s\")\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"logger.info(\"=\" * 50)\n",
"logger.info(\"SiaN Training Pipeline\")\n",
"logger.info(\"=\" * 50)\n",
"\n",
"def create_dataset():\n",
" logger.info(\"Creating data loaders...\")\n",
" train_loader, val_loader = create_data_loaders(\n",
" root_dir=config[\"data_dir\"],\n",
" batch_size=config[\"batch_size\"],\n",
" train_split=config[\"train_split\"],\n",
" num_workers=config[\"num_workers\"],\n",
" image_size=config[\"image_size\"],\n",
" )\n",
" logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
" return train_loader, val_loader\n",
"dataset_info = get_dataset_info()\n",
"logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n",
"\n",
"train_loader, val_loader = create_data_loaders(\n",
" root_dir=config[\"data_dir\"],\n",
" batch_size=config[\"batch_size\"],\n",
" train_split=config[\"train_split\"],\n",
" num_workers=config[\"num_workers\"],\n",
" image_size=config[\"image_size\"],\n",
")\n",
"logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
"\n",
"def create_model():\n",
" logger.info(\"Creating model...\")\n",
" model = HomographyCNN6(\n",
" input_channels=3,\n",
" backbone_name=config[\"backbone\"],\n",
" pretrained=True,\n",
" dropout_rate=config[\"dropout_rate\"]\n",
" )\n",
" logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n",
" return model\n",
"model = HomographyCNN6(\n",
" input_channels=3,\n",
" backbone_name=config[\"backbone\"],\n",
" pretrained=True,\n",
" dropout_rate=config[\"dropout_rate\"]\n",
")\n",
"logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"logger.info(f\"Using device: {device}\")\n",
"\n",
"def train_model(model, train_loader, val_loader):\n",
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" logger.info(f\"Using device: {device}\")\n",
" \n",
" trainer = HomographyTrainer(model, train_loader, val_loader, device)\n",
" logger.info(\"Starting training...\")\n",
" trainer.train(config[\"epochs\"])\n",
" logger.info(\"Training completed\")\n",
" return trainer\n",
"trainer = HomographyTrainer(model, train_loader, val_loader, device)\n",
"logger.info(\"Starting training...\")\n",
"trainer.train(config[\"epochs\"])\n",
"logger.info(\"Training completed\")\n",
"\n",
"logger.info(\"Analyzing model...\")\n",
"results = analyze_training(trainer)\n",
"logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n",
"\n",
"def analyze_model(trainer):\n",
" logger.info(\"Analyzing model...\")\n",
" results = analyze_training(trainer)\n",
" logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n",
" return results\n",
"\n",
"\n",
"def main():\n",
" logger.info(\"=\" * 50)\n",
" logger.info(\"SiaN Training Pipeline\")\n",
" logger.info(\"=\" * 50)\n",
" \n",
" dataset_info = get_dataset_info()\n",
" logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n",
" \n",
" train_loader, val_loader = create_dataset()\n",
" model = create_model()\n",
" trainer = train_model(model, train_loader, val_loader)\n",
" results = analyze_model(trainer)\n",
" \n",
" logger.info(\"=\" * 50)\n",
" logger.info(\"Pipeline completed successfully\")\n",
" logger.info(\"=\" * 50)\n",
" \n",
" return trainer, results\n",
"logger.info(\"=\" * 50)\n",
"logger.info(\"Pipeline completed successfully\")\n",
"logger.info(\"=\" * 50)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "Run main() to execute the full pipeline.\n"
"source": "## Main Pipeline\n\nExecutes the full training workflow:\n1. Load dataset info\n2. Create data loaders\n3. Initialize model\n4. Train with validation\n5. Analyze and export results\n\n**Outputs:**\n- Model checkpoints in `runs/checkpoints/`\n- TensorBoard logs in `runs/`\n"
},
{
"cell_type": "code",
@@ -479,8 +461,7 @@
"metadata": {},
"outputs": [],
"source": [
"if __name__ == \"__main__\":\n",
" main()\n",
"!zip artefacts.zip runs/gan_training/checkpoints/best_model.pt\n",
"\n"
]
}