feat: add GAN
This commit is contained in:
121
models/GAN/_schema.py
Normal file
121
models/GAN/_schema.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# _schema.py
|
||||
|
||||
# === IMPORTS ===
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset, Subset
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
# markdown
|
||||
"""# Configuration
|
||||
|
||||
Global settings for the Google -> Yandex GAN:
|
||||
- Dataset path and image size
|
||||
- Optimizer and training hyperparameters
|
||||
- Device preference with safe CUDA compatibility checks
|
||||
- GAN, L1, SSIM and edge reconstruction weights
|
||||
- Output directories for checkpoints and generated samples"""
|
||||
# code: ./src/config.py
|
||||
|
||||
# markdown
|
||||
"""## Dataset
|
||||
|
||||
Google/Yandex paired image loader.
|
||||
|
||||
**Direction:**
|
||||
- `google_img` is the generator input
|
||||
- `yandex_img` is the target image from the same pair
|
||||
|
||||
**Returns:**
|
||||
- Batch dict with `google_img`, `yandex_img`, `idx`"""
|
||||
# code: ./src/dataloader.py
|
||||
|
||||
# code: ./src/test_dataloader.py
|
||||
|
||||
# markdown
|
||||
"""## Model
|
||||
|
||||
Pix2pix-style GAN for Google -> Yandex map translation.
|
||||
|
||||
**Generator:**
|
||||
- `GeneratorUNet`
|
||||
- Input: Google image `(B, 3, H, W)`
|
||||
- Output: generated Yandex image `(B, 3, H, W)`
|
||||
|
||||
**Discriminator:**
|
||||
- `DiscriminatorPatchGAN`
|
||||
- Input pair: `(google_img, yandex_img)`
|
||||
- Learns to distinguish real pairs from `(google_img, fake_yandex)`
|
||||
|
||||
**Generator loss:**
|
||||
- adversarial loss
|
||||
- `lambda_L1 * L1(fake_yandex, yandex_img)`
|
||||
- `lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)`
|
||||
- `lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)`
|
||||
|
||||
The generator uses bilinear upsampling followed by convolution to avoid
|
||||
checkerboard artifacts from transposed convolutions."""
|
||||
# code: ./src/model.py
|
||||
|
||||
# markdown
|
||||
"""## Training
|
||||
|
||||
`GANTrainer` trains discriminator and generator alternately.
|
||||
|
||||
**Training step:**
|
||||
1. Generate `fake_yandex = G(google_img)`
|
||||
2. Train discriminator on real pair `(google_img, yandex_img)` and fake pair `(google_img, fake_yandex)`
|
||||
3. Train generator against discriminator and paired Yandex target
|
||||
|
||||
**Checkpoint saving:**
|
||||
- `best.pth`
|
||||
- `epoch_N.pth`
|
||||
- `final.pth`"""
|
||||
# code: ./src/trainer.py
|
||||
|
||||
# markdown
|
||||
"""## Analysis
|
||||
|
||||
Visualization helpers for generated samples and collected training metrics.
|
||||
|
||||
Training history plot contains:
|
||||
1. Generator loss
|
||||
2. Discriminator loss
|
||||
3. L1 loss against the paired Yandex target
|
||||
4. SSIM structure loss
|
||||
5. Sobel edge loss
|
||||
6. Best-checkpoint reconstruction score
|
||||
|
||||
The sample grid contains:
|
||||
1. Google input
|
||||
2. Generated Yandex
|
||||
3. Original Yandex target"""
|
||||
# code: ./src/analyze.py
|
||||
|
||||
# markdown
|
||||
"""## Main Pipeline
|
||||
|
||||
Executes the full GAN workflow:
|
||||
1. Create config
|
||||
2. Build paired data loaders
|
||||
3. Initialize Google -> Yandex GAN
|
||||
4. Train with validation
|
||||
5. Save checkpoints in `runs/checkpoints/`
|
||||
6. Show loss plots and generated sample grid
|
||||
|
||||
This block is intentionally top-level, not wrapped in `main()`, so notebook
|
||||
variables such as `model`, `trainer`, `train_loader`, `val_loader`, and
|
||||
`training_analysis` remain available for debugging."""
|
||||
# code: ./src/main.py
|
||||
|
||||
# # shell:
|
||||
# !zip artefacts.zip runs/checkpoints/best.pth runs/images/training_history.png runs/images/generation_samples.png
|
||||
Reference in New Issue
Block a user