feat: add GAN

This commit is contained in:
2026-05-30 14:49:40 +03:00
parent 6477ce0776
commit 72e1950127
29 changed files with 2670 additions and 361 deletions

121
models/GAN/_schema.py Normal file
View File

@@ -0,0 +1,121 @@
# _schema.py
# === IMPORTS ===
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from tqdm import tqdm
# markdown
"""# Configuration
Global settings for the Google -> Yandex GAN:
- Dataset path and image size
- Optimizer and training hyperparameters
- Device preference with safe CUDA compatibility checks
- GAN, L1, SSIM and edge reconstruction weights
- Output directories for checkpoints and generated samples"""
# code: ./src/config.py
# markdown
"""## Dataset
Google/Yandex paired image loader.
**Direction:**
- `google_img` is the generator input
- `yandex_img` is the target image from the same pair
**Returns:**
- Batch dict with `google_img`, `yandex_img`, `idx`"""
# code: ./src/dataloader.py
# code: ./src/test_dataloader.py
# markdown
"""## Model
Pix2pix-style GAN for Google -> Yandex map translation.
**Generator:**
- `GeneratorUNet`
- Input: Google image `(B, 3, H, W)`
- Output: generated Yandex image `(B, 3, H, W)`
**Discriminator:**
- `DiscriminatorPatchGAN`
- Input pair: `(google_img, yandex_img)`
- Learns to distinguish real pairs from `(google_img, fake_yandex)`
**Generator loss:**
- adversarial loss
- `lambda_L1 * L1(fake_yandex, yandex_img)`
- `lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)`
- `lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)`
The generator uses bilinear upsampling followed by convolution to avoid
checkerboard artifacts from transposed convolutions."""
# code: ./src/model.py
# markdown
"""## Training
`GANTrainer` trains discriminator and generator alternately.
**Training step:**
1. Generate `fake_yandex = G(google_img)`
2. Train discriminator on real pair `(google_img, yandex_img)` and fake pair `(google_img, fake_yandex)`
3. Train generator against discriminator and paired Yandex target
**Checkpoint saving:**
- `best.pth`
- `epoch_N.pth`
- `final.pth`"""
# code: ./src/trainer.py
# markdown
"""## Analysis
Visualization helpers for generated samples and collected training metrics.
Training history plot contains:
1. Generator loss
2. Discriminator loss
3. L1 loss against the paired Yandex target
4. SSIM structure loss
5. Sobel edge loss
6. Best-checkpoint reconstruction score
The sample grid contains:
1. Google input
2. Generated Yandex
3. Original Yandex target"""
# code: ./src/analyze.py
# markdown
"""## Main Pipeline
Executes the full GAN workflow:
1. Create config
2. Build paired data loaders
3. Initialize Google -> Yandex GAN
4. Train with validation
5. Save checkpoints in `runs/checkpoints/`
6. Show loss plots and generated sample grid
This block is intentionally top-level, not wrapped in `main()`, so notebook
variables such as `model`, `trainer`, `train_loader`, `val_loader`, and
`training_analysis` remain available for debugging."""
# code: ./src/main.py
# # shell:
# !zip artefacts.zip runs/checkpoints/best.pth runs/images/training_history.png runs/images/generation_samples.png