improve schema
This commit is contained in:
3
models/SiaN/.gitignore
vendored
3
models/SiaN/.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
runs
|
||||
*.gen.py
|
||||
*.gen.py
|
||||
*.img
|
||||
@@ -1,4 +1,8 @@
|
||||
# _schema.md
|
||||
# SiaN Schema
|
||||
|
||||
Notebook structure definition for the SiaN model.
|
||||
|
||||
---
|
||||
|
||||
## Format
|
||||
|
||||
@@ -10,14 +14,18 @@
|
||||
# markdown
|
||||
"""Description"""
|
||||
|
||||
# inline:
|
||||
<custom code>
|
||||
# shell:
|
||||
<shell commands>
|
||||
```
|
||||
|
||||
## Directives
|
||||
|
||||
- `# code:` - include file from src/
|
||||
- `# markdown` - description block
|
||||
- `# inline:` - custom code cell
|
||||
| Directive | Description |
|
||||
|----------------|------------------------------------|
|
||||
| `# code:` | Include file from src/ |
|
||||
| `# markdown` | Description block |
|
||||
| `# shell:` | Shell script cell (written as `# # shell:` with `# `-prefixed command lines in `_schema.py`) |
|
||||
|
||||
---
|
||||
|
||||
`build.py` generates the notebook from `_schema.py`.
|
||||
|
||||
@@ -20,28 +20,87 @@ from tqdm import tqdm
|
||||
|
||||
# code: ./src/utils.py
|
||||
# markdown
|
||||
"""# SiaN Model"""
|
||||
"""# Configuration
|
||||
|
||||
Global settings for:
|
||||
- Data paths and image parameters
|
||||
- Training hyperparameters
|
||||
- Model architecture options
|
||||
|
||||
Contains the `config` dictionary used across all modules."""
|
||||
|
||||
# code: ./src/dataloader.py
|
||||
# markdown
|
||||
"""Dataset for Google/Yandex image pairs with homography augmentation."""
|
||||
"""## Dataset
|
||||
|
||||
Google/Yandex image pair loader with homography augmentation.
|
||||
|
||||
**Features:**
|
||||
- Loads paired images from dual camera sources
|
||||
- Applies random homography transformations
|
||||
- Supports configurable train/val split
|
||||
|
||||
**Returns:**
|
||||
- Batch dict with `google_img`, `yandex_img`, `homography_params`"""
|
||||
|
||||
# code: ./src/model.py
|
||||
# markdown
|
||||
"""HomographyCNN6 predicts 6 params: rx, ry, rz, tx, ty, scale."""
|
||||
"""## Model
|
||||
|
||||
`HomographyCNN6` — CNN architecture for homography estimation.
|
||||
|
||||
**Output:** 6 parameters
|
||||
- `rx, ry, rz` — rotation angles (radians)
|
||||
- `tx, ty` — translation offsets
|
||||
- `scale` — isotropic scale factor
|
||||
|
||||
**Architecture:**
|
||||
- Dual-branch CNN (Google + Yandex images)
|
||||
- Shared backbone (configurable: resnet18/34/50)
|
||||
- Fusion head with dropout regularization"""
|
||||
|
||||
# code: ./src/train.py
|
||||
# markdown
|
||||
"""HomographyTrainer manages training loop with validation."""
|
||||
"""## Training
|
||||
|
||||
`HomographyTrainer` — training loop with validation and checkpointing.
|
||||
|
||||
**Features:**
|
||||
- Epoch-based training with tqdm progress bar
|
||||
- Adam optimizer with configurable LR
|
||||
- Validation after each epoch
|
||||
- Best model auto-save
|
||||
- Periodic checkpoints (every N epochs via `save_every_n_epochs`)
|
||||
|
||||
**Checkpoint saving:**
|
||||
- `best_model.pt` — lowest validation loss
|
||||
- `checkpoint_epoch_N.pt` — periodic saves"""
|
||||
|
||||
# code: ./src/analyze.py
|
||||
# markdown
|
||||
"""Visualization and analysis of predictions."""
|
||||
"""## Analysis
|
||||
|
||||
Visualization and evaluation tools:
|
||||
|
||||
- Training metrics plots (loss curves)
|
||||
- Prediction visualization on sample images
|
||||
- Error analysis and statistics"""
|
||||
|
||||
# code: ./src/main.py
|
||||
# markdown
|
||||
"""Run main() to execute the full pipeline."""
|
||||
"""## Main Pipeline
|
||||
|
||||
# inline:
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Executes the full training workflow:
|
||||
1. Load dataset info
|
||||
2. Create data loaders
|
||||
3. Initialize model
|
||||
4. Train with validation
|
||||
5. Analyze and export results
|
||||
|
||||
**Outputs:**
|
||||
- Model checkpoints in `runs/checkpoints/`
|
||||
- TensorBoard logs in `runs/`
|
||||
- Analysis plots"""
|
||||
|
||||
# # shell:
|
||||
# !zip artefacts.zip runs/gan_training/checkpoints/best_model.pt
|
||||
|
||||
@@ -42,23 +42,34 @@ def parse_schema(schema_path):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if stripped.startswith('# inline:'):
|
||||
if stripped.startswith('# inline:') or stripped.startswith('# # shell:'):
|
||||
directive = 'inline' if stripped.startswith('# inline:') else 'shell'
|
||||
i += 1
|
||||
code_lines = []
|
||||
while i < len(lines):
|
||||
stripped = lines[i].strip()
|
||||
if stripped.startswith('#'):
|
||||
if stripped.startswith('# code:') or stripped.startswith('# inline:') or stripped == '# markdown' or stripped.startswith('# ==='):
|
||||
break
|
||||
i += 1
|
||||
continue
|
||||
if directive == 'shell':
|
||||
if stripped.startswith('# ') or stripped.startswith('# # '):
|
||||
line_content = lines[i].strip()
|
||||
if line_content.startswith('# # '):
|
||||
line_content = line_content[3:].lstrip()
|
||||
else:
|
||||
line_content = line_content[2:]
|
||||
code_lines.append(line_content)
|
||||
i += 1
|
||||
continue
|
||||
if stripped.startswith('#'):
|
||||
if stripped.startswith('# code:') or stripped.startswith('# inline:') or stripped.startswith('# # shell:') or stripped == '# markdown' or stripped.startswith('# ==='):
|
||||
break
|
||||
i += 1
|
||||
continue
|
||||
if stripped == '':
|
||||
i += 1
|
||||
continue
|
||||
code_lines.append(lines[i])
|
||||
i += 1
|
||||
if code_lines:
|
||||
items.append(('inline', '\n'.join(code_lines).rstrip()))
|
||||
items.append((directive, '\n'.join(code_lines).rstrip()))
|
||||
continue
|
||||
|
||||
if stripped == '# markdown':
|
||||
@@ -180,6 +191,20 @@ def build_notebook(schema_path, output_path=None):
|
||||
"source": [line + "\n" for line in lines]
|
||||
}
|
||||
cells.append(cell)
|
||||
elif item_type == 'shell':
|
||||
lines = item_content.split('\n')
|
||||
while lines and lines[-1].strip() == '':
|
||||
lines.pop()
|
||||
if lines:
|
||||
lines.append('')
|
||||
cell = {
|
||||
"cell_type": "code",
|
||||
"execution_count": None,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [line + "\n" for line in lines]
|
||||
}
|
||||
cells.append(cell)
|
||||
elif item_type == 'code':
|
||||
try:
|
||||
content = read_src_file(item_content, schema_dir)
|
||||
|
||||
@@ -42,6 +42,7 @@
|
||||
" \"dropout_rate\": 0.3,\n",
|
||||
" \"backbone\": \"resnet18\",\n",
|
||||
" \"output_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\models\\SiaN\\runs\",\n",
|
||||
" \"save_every_n_epochs\": 15,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@@ -88,7 +89,7 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "# SiaN Model\n"
|
||||
"source": "# Configuration\n\nGlobal settings for:\n- Data paths and image parameters\n- Training hyperparameters\n- Model architecture options\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -185,7 +186,7 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "Dataset for Google/Yandex image pairs with homography augmentation.\n"
|
||||
"source": "## Dataset\n\nGoogle/Yandex image pair loader with homography augmentation.\n\n**Features:**\n- Loads paired images from dual camera sources\n- Applies random homography transformations\n- Supports configurable train/val split\n\n**Returns:**\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -237,7 +238,7 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "HomographyCNN6 predicts 6 params: rx, ry, rz, tx, ty, scale.\n"
|
||||
"source": "## Model\n\n`HomographyCNN6` — CNN architecture for homography estimation.\n\n**Output:** 6 parameters\n- `rx, ry, rz` — rotation angles (radians)\n- `tx, ty` — translation offsets\n- `scale` — isotropic scale factor\n\n**Architecture:**\n- Dual-branch CNN (Google + Yandex images)\n- Shared backbone (configurable: resnet18/34/50)\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -310,6 +311,10 @@
|
||||
" self.save_checkpoint(epoch, is_best=True)\n",
|
||||
" print(f\"Best model saved (val loss: {val_metrics['loss']:.4f})\")\n",
|
||||
"\n",
|
||||
" if epoch % config[\"save_every_n_epochs\"] == 0:\n",
|
||||
" self.save_checkpoint(epoch, is_best=False)\n",
|
||||
" print(f\"Checkpoint saved at epoch {epoch}\")\n",
|
||||
"\n",
|
||||
" self.writer.close()\n",
|
||||
"\n",
|
||||
" def save_checkpoint(self, epoch, is_best=False):\n",
|
||||
@@ -325,7 +330,7 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "HomographyTrainer manages training loop with validation.\n"
|
||||
"source": "## Training\n\n`HomographyTrainer` — training loop with validation and checkpointing.\n\n**Features:**\n- Epoch-based training with tqdm progress bar\n- Adam optimizer with configurable LR\n- Validation after each epoch\n- Best model auto-save\n- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n\n**Checkpoint saving:**\n- `best_model.pt` — lowest validation loss\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -388,7 +393,7 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "Visualization and analysis of predictions.\n"
|
||||
"source": "## Analysis\n\nVisualization and evaluation tools:\n\n- Training metrics plots (loss curves)\n- Prediction visualization on sample images\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -403,75 +408,52 @@
|
||||
"logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(message)s\")\n",
|
||||
"logger = logging.getLogger(__name__)\n",
|
||||
"\n",
|
||||
"logger.info(\"=\" * 50)\n",
|
||||
"logger.info(\"SiaN Training Pipeline\")\n",
|
||||
"logger.info(\"=\" * 50)\n",
|
||||
"\n",
|
||||
"def create_dataset():\n",
|
||||
" logger.info(\"Creating data loaders...\")\n",
|
||||
" train_loader, val_loader = create_data_loaders(\n",
|
||||
" root_dir=config[\"data_dir\"],\n",
|
||||
" batch_size=config[\"batch_size\"],\n",
|
||||
" train_split=config[\"train_split\"],\n",
|
||||
" num_workers=config[\"num_workers\"],\n",
|
||||
" image_size=config[\"image_size\"],\n",
|
||||
" )\n",
|
||||
" logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
|
||||
" return train_loader, val_loader\n",
|
||||
"dataset_info = get_dataset_info()\n",
|
||||
"logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n",
|
||||
"\n",
|
||||
"train_loader, val_loader = create_data_loaders(\n",
|
||||
" root_dir=config[\"data_dir\"],\n",
|
||||
" batch_size=config[\"batch_size\"],\n",
|
||||
" train_split=config[\"train_split\"],\n",
|
||||
" num_workers=config[\"num_workers\"],\n",
|
||||
" image_size=config[\"image_size\"],\n",
|
||||
")\n",
|
||||
"logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
|
||||
"\n",
|
||||
"def create_model():\n",
|
||||
" logger.info(\"Creating model...\")\n",
|
||||
" model = HomographyCNN6(\n",
|
||||
" input_channels=3,\n",
|
||||
" backbone_name=config[\"backbone\"],\n",
|
||||
" pretrained=True,\n",
|
||||
" dropout_rate=config[\"dropout_rate\"]\n",
|
||||
" )\n",
|
||||
" logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n",
|
||||
" return model\n",
|
||||
"model = HomographyCNN6(\n",
|
||||
" input_channels=3,\n",
|
||||
" backbone_name=config[\"backbone\"],\n",
|
||||
" pretrained=True,\n",
|
||||
" dropout_rate=config[\"dropout_rate\"]\n",
|
||||
")\n",
|
||||
"logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n",
|
||||
"\n",
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"logger.info(f\"Using device: {device}\")\n",
|
||||
"\n",
|
||||
"def train_model(model, train_loader, val_loader):\n",
|
||||
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
" logger.info(f\"Using device: {device}\")\n",
|
||||
" \n",
|
||||
" trainer = HomographyTrainer(model, train_loader, val_loader, device)\n",
|
||||
" logger.info(\"Starting training...\")\n",
|
||||
" trainer.train(config[\"epochs\"])\n",
|
||||
" logger.info(\"Training completed\")\n",
|
||||
" return trainer\n",
|
||||
"trainer = HomographyTrainer(model, train_loader, val_loader, device)\n",
|
||||
"logger.info(\"Starting training...\")\n",
|
||||
"trainer.train(config[\"epochs\"])\n",
|
||||
"logger.info(\"Training completed\")\n",
|
||||
"\n",
|
||||
"logger.info(\"Analyzing model...\")\n",
|
||||
"results = analyze_training(trainer)\n",
|
||||
"logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n",
|
||||
"\n",
|
||||
"def analyze_model(trainer):\n",
|
||||
" logger.info(\"Analyzing model...\")\n",
|
||||
" results = analyze_training(trainer)\n",
|
||||
" logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n",
|
||||
" return results\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def main():\n",
|
||||
" logger.info(\"=\" * 50)\n",
|
||||
" logger.info(\"SiaN Training Pipeline\")\n",
|
||||
" logger.info(\"=\" * 50)\n",
|
||||
" \n",
|
||||
" dataset_info = get_dataset_info()\n",
|
||||
" logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n",
|
||||
" \n",
|
||||
" train_loader, val_loader = create_dataset()\n",
|
||||
" model = create_model()\n",
|
||||
" trainer = train_model(model, train_loader, val_loader)\n",
|
||||
" results = analyze_model(trainer)\n",
|
||||
" \n",
|
||||
" logger.info(\"=\" * 50)\n",
|
||||
" logger.info(\"Pipeline completed successfully\")\n",
|
||||
" logger.info(\"=\" * 50)\n",
|
||||
" \n",
|
||||
" return trainer, results\n",
|
||||
"logger.info(\"=\" * 50)\n",
|
||||
"logger.info(\"Pipeline completed successfully\")\n",
|
||||
"logger.info(\"=\" * 50)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "Run main() to execute the full pipeline.\n"
|
||||
"source": "## Main Pipeline\n\nExecutes the full training workflow:\n1. Load dataset info\n2. Create data loaders\n3. Initialize model\n4. Train with validation\n5. Analyze and export results\n\n**Outputs:**\n- Model checkpoints in `runs/checkpoints/`\n- TensorBoard logs in `runs/`\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -479,8 +461,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if __name__ == \"__main__\":\n",
|
||||
" main()\n",
|
||||
"!zip artefacts.zip runs/gan_training/checkpoints/best_model.pt\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0,
|
||||
split = int(train_split * len(indices))
|
||||
|
||||
train_ds = Subset(aug_ds if augment_train else full_ds, indices[:split])
|
||||
val_ds = Subset(full_ds, indices[split:])
|
||||
val_ds = Subset(aug_ds, indices[split:])
|
||||
|
||||
return (DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True),
|
||||
DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
@@ -12,69 +13,42 @@ from .utils import config
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info("SiaN Training Pipeline")
|
||||
logger.info("=" * 50)
|
||||
|
||||
def create_dataset():
|
||||
logger.info("Creating data loaders...")
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=config["data_dir"],
|
||||
batch_size=config["batch_size"],
|
||||
train_split=config["train_split"],
|
||||
num_workers=config["num_workers"],
|
||||
image_size=config["image_size"],
|
||||
)
|
||||
logger.info(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
|
||||
return train_loader, val_loader
|
||||
dataset_info = get_dataset_info()
|
||||
logger.info(f"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}")
|
||||
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=config["data_dir"],
|
||||
batch_size=config["batch_size"],
|
||||
train_split=config["train_split"],
|
||||
num_workers=config["num_workers"],
|
||||
image_size=config["image_size"],
|
||||
)
|
||||
logger.info(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
|
||||
|
||||
def create_model():
|
||||
logger.info("Creating model...")
|
||||
model = HomographyCNN6(
|
||||
input_channels=3,
|
||||
backbone_name=config["backbone"],
|
||||
pretrained=True,
|
||||
dropout_rate=config["dropout_rate"]
|
||||
)
|
||||
logger.info(f"Model created with {count_parameters(model):,} parameters")
|
||||
return model
|
||||
model = HomographyCNN6(
|
||||
input_channels=3,
|
||||
backbone_name=config["backbone"],
|
||||
pretrained=True,
|
||||
dropout_rate=config["dropout_rate"]
|
||||
)
|
||||
logger.info(f"Model created with {count_parameters(model):,} parameters")
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
def train_model(model, train_loader, val_loader):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
trainer = HomographyTrainer(model, train_loader, val_loader, device)
|
||||
logger.info("Starting training...")
|
||||
trainer.train(config["epochs"])
|
||||
logger.info("Training completed")
|
||||
return trainer
|
||||
trainer = HomographyTrainer(model, train_loader, val_loader, device)
|
||||
logger.info("Starting training...")
|
||||
trainer.train(config["epochs"])
|
||||
logger.info("Training completed")
|
||||
|
||||
logger.info("Analyzing model...")
|
||||
results = analyze_training(trainer)
|
||||
logger.info(f"Analysis complete: best_val_loss={results['best_val_loss']:.4f}")
|
||||
|
||||
def analyze_model(trainer):
|
||||
logger.info("Analyzing model...")
|
||||
results = analyze_training(trainer)
|
||||
logger.info(f"Analysis complete: best_val_loss={results['best_val_loss']:.4f}")
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
logger.info("=" * 50)
|
||||
logger.info("SiaN Training Pipeline")
|
||||
logger.info("=" * 50)
|
||||
|
||||
dataset_info = get_dataset_info()
|
||||
logger.info(f"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}")
|
||||
|
||||
train_loader, val_loader = create_dataset()
|
||||
model = create_model()
|
||||
trainer = train_model(model, train_loader, val_loader)
|
||||
results = analyze_model(trainer)
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info("Pipeline completed successfully")
|
||||
logger.info("=" * 50)
|
||||
|
||||
return trainer, results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
logger.info("=" * 50)
|
||||
logger.info("Pipeline completed successfully")
|
||||
logger.info("=" * 50)
|
||||
|
||||
@@ -72,6 +72,10 @@ class HomographyTrainer:
|
||||
self.save_checkpoint(epoch, is_best=True)
|
||||
print(f"Best model saved (val loss: {val_metrics['loss']:.4f})")
|
||||
|
||||
if epoch % config["save_every_n_epochs"] == 0:
|
||||
self.save_checkpoint(epoch, is_best=False)
|
||||
print(f"Checkpoint saved at epoch {epoch}")
|
||||
|
||||
self.writer.close()
|
||||
|
||||
def save_checkpoint(self, epoch, is_best=False):
|
||||
|
||||
@@ -12,6 +12,7 @@ config = {
|
||||
"dropout_rate": 0.3,
|
||||
"backbone": "resnet18",
|
||||
"output_dir": r"C:\Users\admin\Projects\autopilot\models\SiaN\runs",
|
||||
"save_every_n_epochs": 15,
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user