122 lines
3.1 KiB
Python
122 lines
3.1 KiB
Python
# _schema.py
|
|
|
|
# === IMPORTS ===
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
import matplotlib.pyplot as plt
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from PIL import Image
|
|
from torch.utils.data import DataLoader, Dataset, Subset
|
|
from torchvision import transforms
|
|
from tqdm import tqdm
|
|
|
|
# markdown
|
|
"""# Configuration
|
|
|
|
Global settings for the Google -> Yandex GAN:
|
|
- Dataset path and image size
|
|
- Optimizer and training hyperparameters
|
|
- Device preference with safe CUDA compatibility checks
|
|
- Loss weights for the adversarial (GAN), L1, SSIM, and edge reconstruction terms
|
|
- Output directories for checkpoints and generated samples"""
|
|
# code: ./src/config.py
|
|
|
|
# markdown
|
|
"""## Dataset
|
|
|
|
Google/Yandex paired image loader.
|
|
|
|
**Direction:**
|
|
- `google_img` is the generator input
|
|
- `yandex_img` is the target image from the same pair
|
|
|
|
**Returns:**
|
|
- Batch dict with keys `google_img`, `yandex_img`, and `idx`"""
|
|
# code: ./src/dataloader.py
|
|
|
|
# code: ./src/test_dataloader.py
|
|
|
|
# markdown
|
|
"""## Model
|
|
|
|
Pix2pix-style GAN for Google -> Yandex map translation.
|
|
|
|
**Generator:**
|
|
- `GeneratorUNet`
|
|
- Input: Google image `(B, 3, H, W)`
|
|
- Output: generated Yandex image `(B, 3, H, W)`
|
|
|
|
**Discriminator:**
|
|
- `DiscriminatorPatchGAN`
|
|
- Input pair: `(google_img, yandex_img)`
|
|
- Learns to distinguish real pairs `(google_img, yandex_img)` from fake pairs `(google_img, fake_yandex)`
|
|
|
|
**Generator loss:**
|
|
- adversarial loss
|
|
- `lambda_L1 * L1(fake_yandex, yandex_img)`
|
|
- `lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)`
|
|
- `lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)`
|
|
|
|
The generator uses bilinear upsampling followed by convolution to avoid
|
|
checkerboard artifacts from transposed convolutions."""
|
|
# code: ./src/model.py
|
|
|
|
# markdown
|
|
"""## Training
|
|
|
|
`GANTrainer` trains the discriminator and the generator alternately.
|
|
|
|
**Training step:**
|
|
1. Generate `fake_yandex = G(google_img)`
|
|
2. Train the discriminator on the real pair `(google_img, yandex_img)` and the fake pair `(google_img, fake_yandex)`
|
|
3. Train the generator against the discriminator and the paired Yandex target
|
|
|
|
**Checkpoint saving:**
|
|
- `best.pth`
|
|
- `epoch_N.pth`
|
|
- `final.pth`"""
|
|
# code: ./src/trainer.py
|
|
|
|
# markdown
|
|
"""## Analysis
|
|
|
|
Visualization helpers for generated samples and collected training metrics.
|
|
|
|
The training history plot contains:
|
|
1. Generator loss
|
|
2. Discriminator loss
|
|
3. L1 loss against the paired Yandex target
|
|
4. SSIM structure loss
|
|
5. Sobel edge loss
|
|
6. Best-checkpoint reconstruction score
|
|
|
|
The sample grid contains:
|
|
1. Google input
|
|
2. Generated Yandex
|
|
3. Original Yandex target"""
|
|
# code: ./src/analyze.py
|
|
|
|
# markdown
|
|
"""## Main Pipeline
|
|
|
|
Executes the full GAN workflow:
|
|
1. Create config
|
|
2. Build paired data loaders
|
|
3. Initialize the Google -> Yandex GAN
|
|
4. Train with validation
|
|
5. Save checkpoints in `runs/checkpoints/`
|
|
6. Show the loss plots and the generated sample grid
|
|
|
|
This block is intentionally top-level, not wrapped in `main()`, so notebook
|
|
variables such as `model`, `trainer`, `train_loader`, `val_loader`, and
|
|
`training_analysis` remain available for debugging."""
|
|
# code: ./src/main.py
|
|
|
|
# # shell:
|
|
# !zip artefacts.zip runs/checkpoints/best.pth runs/images/training_history.png runs/images/generation_samples.png
|