1064 lines
46 KiB
Plaintext
1064 lines
46 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import json\n",
|
|
"import os\n",
|
|
"from pathlib import Path\n",
|
|
"from typing import Any, Dict, List, Optional, Tuple\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"import torch.nn.functional as F\n",
|
|
"from PIL import Image\n",
|
|
"from torch.utils.data import DataLoader, Dataset, Subset\n",
|
|
"from torchvision import transforms\n",
|
|
"from tqdm import tqdm\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": "# Configuration\n\nGlobal settings for the Google -> Yandex GAN:\n- Dataset path and image size\n- Optimizer and training hyperparameters\n- Device preference with safe CUDA compatibility checks\n- GAN, L1, SSIM and edge reconstruction weights\n"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\"\"\"Configuration for GAN training.\"\"\"\n",
|
|
"\n",
|
|
"\n",
|
|
def create_config():
    """Create the default configuration dictionary for Google -> Yandex GAN training.

    The two filesystem locations can be overridden without editing the notebook via
    the ``YAGO_DATA_DIR`` and ``YAGO_OUTPUT_DIR`` environment variables; when those
    are unset the original (machine-specific) default paths are used.

    Returns:
        A new ``dict`` on every call (safe to mutate) holding optimizer, training,
        loss-weight, early-stopping, output, logging and data settings.
    """
    return {
        # Optimizer params
        "learning_rate": 2e-4,
        # Discriminator LR = learning_rate * discriminator_lr_factor (read by GANTrainer).
        "discriminator_lr_factor": 0.5,
        "beta1": 0.5,
        "beta2": 0.999,
        # Training params
        "batch_size": 32,
        "epochs": 100,
        "prefer_cuda": True,
        # GAN params
        "gan_mode": "lsgan",
        "lambda_GAN": 0.5,
        "lambda_L1": 150.0,
        "lambda_SSIM": 25.0,
        "lambda_edge": 20.0,
        "discriminator_update_interval": 1,
        # Regularization: max gradient norm for both networks
        "grad_clip": 1.0,
        # Early stopping
        "early_stopping_patience": 25,
        # Output (override with YAGO_OUTPUT_DIR instead of editing this absolute path)
        "output_dir": os.environ.get(
            "YAGO_OUTPUT_DIR",
            r"C:\Users\admin\Projects\autopilot\models\GAN\runs",
        ),
        # Logging
        "log_interval": 10,
        "save_interval": 5,
        "num_visual_samples": 4,
        # Data (override with YAGO_DATA_DIR instead of editing this absolute path)
        "data_dir": os.environ.get(
            "YAGO_DATA_DIR",
            r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        ),
        # (height, width); kept as a list so the config stays JSON-serializable
        "image_size": [256, 256],
        "train_split": 0.8,
        "num_workers": 0,
    }
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
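"source": "## Dataset\n\nGoogle/Yandex paired image loader.\n\n**Direction:**\n- `google_img` is the generator input\n- `yandex_img` is the target image from the same pair\n\n**Returns:** one dict per pair with\n- `google_img`: RGB tensor `(3, H, W)` normalized to `[-1, 1]`\n- `yandex_img`: RGB tensor `(3, H, W)` normalized to `[-1, 1]`\n- `idx`: pair index parsed from the file name (`torch.long`)\n"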
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\"\"\"Data loader for Google-to-Yandex image translation.\"\"\"\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
class YaGoDataset(Dataset):
    """Paired Google (generator input) / Yandex (target) map tiles.

    Files are matched by a shared numeric prefix, e.g. ``0001_google.png`` with
    ``0001_yandex.png``. Each item is a dict:
        google_img: RGB float tensor (3, H, W) normalized to [-1, 1]
        yandex_img: RGB float tensor (3, H, W) normalized to [-1, 1]
        idx:        pair index parsed from the file name (``torch.long`` scalar)
    """

    _GOOGLE_SUFFIX = "_google.png"
    _YANDEX_SUFFIX = "_yandex.png"

    def __init__(
        self,
        root_dir: str,
        image_size: Tuple[int, int] = (256, 256),
        augment: bool = False,
    ):
        """
        Args:
            root_dir: Directory with images named ``<idx>_google.png`` and ``<idx>_yandex.png``
                (conventionally zero-padded, ``{idx:04d}``).
            image_size: Target image size (height, width).
            augment: Stored only; augmentation is not implemented.
        """
        self.root_dir = root_dir
        self.image_size = image_size
        self.augment = augment

        # Discover image pairs once, up front.
        self.pairs = self._find_pairs()

        # ToTensor scales to [0, 1]; Normalize(0.5, 0.5) then maps each channel to [-1, 1].
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    def _find_pairs(self) -> List[Dict]:
        """Return ``{idx, google_path, yandex_path}`` dicts for every complete pair, sorted by idx.

        The Yandex partner is looked up with the *same* prefix as the Google file, so
        unpadded names (``7_google.png`` / ``7_yandex.png``) pair correctly. Names with
        extra underscores (``0001_x_google.png``) are never matched to an unrelated
        ``0001_yandex.png``. Files whose prefix is not a plain decimal number are skipped.
        """
        pairs = []
        for google_file in sorted(os.listdir(self.root_dir)):
            if not google_file.endswith(self._GOOGLE_SUFFIX):
                continue
            prefix = google_file[: -len(self._GOOGLE_SUFFIX)]
            if not (prefix.isascii() and prefix.isdigit()):
                continue

            yandex_path = os.path.join(self.root_dir, prefix + self._YANDEX_SUFFIX)
            if os.path.exists(yandex_path):
                pairs.append(
                    {
                        "idx": int(prefix),
                        "google_path": os.path.join(self.root_dir, google_file),
                        "yandex_path": yandex_path,
                    }
                )

        # Numeric order (stable for ties), independent of the zero-padding width.
        pairs.sort(key=lambda pair: pair["idx"])
        return pairs

    def __len__(self) -> int:
        return len(self.pairs)

    def _load_image(self, path: str):
        """Load ``path`` as RGB, resized to ``image_size``; the file handle is closed on return."""
        with Image.open(path) as img:
            rgb = img.convert("RGB")  # convert() loads pixel data into a new image
        # PIL sizes are (width, height) while image_size is (height, width).
        return rgb.resize((self.image_size[1], self.image_size[0]))

    def __getitem__(self, idx: int) -> dict:
        pair = self.pairs[idx]
        google_tensor = self.transform(self._load_image(pair["google_path"]))
        yandex_tensor = self.transform(self._load_image(pair["yandex_path"]))
        return {
            "google_img": google_tensor,
            "yandex_img": yandex_tensor,
            "idx": torch.tensor(pair["idx"], dtype=torch.long),
        }
|
|
"\n",
|
|
"\n",
|
|
def create_data_loaders(
    root_dir: str,
    batch_size: int = 32,
    train_split: float = 0.8,
    num_workers: int = 0,
    image_size: Tuple[int, int] = (256, 256),
    seed: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders over a random split of the paired dataset.

    Args:
        root_dir: Directory with image pairs
        batch_size: Batch size for both loaders
        train_split: Fraction of pairs used for training, in (0, 1]
        num_workers: DataLoader workers
        image_size: Target image size (height, width)
        seed: Optional seed for a reproducible split; None uses the global torch RNG

    Returns:
        (train_loader, val_loader); the train loader shuffles, the val loader does not.

    Raises:
        ValueError: if train_split is outside (0, 1], no image pairs are found,
            or the split leaves the training set empty.
    """
    if not 0.0 < train_split <= 1.0:
        raise ValueError(f"train_split must be in (0, 1], got {train_split}")

    # Full dataset
    dataset = YaGoDataset(root_dir=root_dir, image_size=image_size)
    dataset_size = len(dataset)
    if dataset_size == 0:
        raise ValueError(f"No *_google.png / *_yandex.png pairs found in {root_dir!r}")

    # Split (a seeded generator makes it reproducible across runs)
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    indices = torch.randperm(dataset_size, generator=generator).tolist()
    train_size = int(train_split * dataset_size)
    if train_size == 0:
        raise ValueError(
            f"train_split={train_split} leaves no training samples out of {dataset_size}"
        )

    train_dataset = Subset(dataset, indices[:train_size])
    val_dataset = Subset(dataset, indices[train_size:])

    # Pinned memory only speeds up host->GPU copies; on CPU-only machines it just warns.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return train_loader, val_loader
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"\n",
|
|
def get_dataset_info():
    """Summarize the configured dataset.

    Returns a dict with ``size`` (number of pairs), ``sample_keys`` (keys of the
    first item, or an empty list) and ``google_shape`` / ``yandex_shape`` (tensor
    shapes of the first item, or ``None`` when the dataset is empty).
    """
    cfg = create_config()
    ds = YaGoDataset(
        root_dir=cfg["data_dir"],
        image_size=tuple(cfg["image_size"]),
    )
    count = len(ds)
    if not count:
        return {
            "size": count,
            "sample_keys": [],
            "google_shape": None,
            "yandex_shape": None,
        }

    first = ds[0]
    return {
        "size": count,
        "sample_keys": list(first.keys()),
        "google_shape": tuple(first["google_img"].shape),
        "yandex_shape": tuple(first["yandex_img"].shape),
    }
|
|
"\n",
|
|
"\n",
|
|
def smoke_test_dataloader(batch_size=4):
    """Build loaders from the default config and report split sizes and first-batch shapes."""
    cfg = create_config()
    loaders = create_data_loaders(
        root_dir=cfg["data_dir"],
        batch_size=batch_size,
        train_split=cfg["train_split"],
        num_workers=cfg["num_workers"],
        image_size=tuple(cfg["image_size"]),
    )
    train_loader, val_loader = loaders
    first_batch = next(iter(train_loader))

    summary = {
        "train_size": len(train_loader.dataset),
        "val_size": len(val_loader.dataset),
    }
    for report_key, batch_key in (
        ("google_batch_shape", "google_img"),
        ("yandex_batch_shape", "yandex_img"),
    ):
        summary[report_key] = tuple(first_batch[batch_key].shape)
    return summary
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": "## Model\n\nPix2pix-style GAN for Google -> Yandex map translation.\n\n**Generator:**\n- `GeneratorUNet`\n- Input: Google image `(B, 3, H, W)`\n- Output: generated Yandex image `(B, 3, H, W)`\n\n**Discriminator:**\n- `DiscriminatorPatchGAN`\n- Input pair: `(google_img, yandex_img)`\n- Learns to distinguish real pairs from `(google_img, fake_yandex)`\n\n**Generator loss:**\n- `lambda_GAN * GANLoss(D(google_img, fake_yandex), real)` (adversarial term)\n- `lambda_L1 * L1(fake_yandex, yandex_img)`\n- `lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)`\n- `lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)`\n\nThe generator uses bilinear upsampling followed by convolution instead of transposed convolutions, to avoid the checkerboard artifacts those tend to produce.\n"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\"\"\"GAN model for image translation Google -> Yandex.\"\"\"\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"def get_compatible_device(prefer_cuda: bool = True, verbose: bool = True) -> torch.device:\n",
|
|
" \"\"\"Return CUDA only when the current PyTorch build supports the GPU arch.\"\"\"\n",
|
|
" if not prefer_cuda or not torch.cuda.is_available():\n",
|
|
" return torch.device(\"cpu\")\n",
|
|
"\n",
|
|
" try:\n",
|
|
" major, minor = torch.cuda.get_device_capability()\n",
|
|
" arch = f\"sm_{major}{minor}\"\n",
|
|
" supported_arches = set(torch.cuda.get_arch_list())\n",
|
|
" gpu_name = torch.cuda.get_device_name()\n",
|
|
" except Exception as exc:\n",
|
|
" if verbose:\n",
|
|
" print(f\"CUDA is visible but cannot be inspected ({exc}); using CPU.\")\n",
|
|
" return torch.device(\"cpu\")\n",
|
|
"\n",
|
|
" if supported_arches and arch not in supported_arches:\n",
|
|
" if verbose:\n",
|
|
" supported = \", \".join(sorted(supported_arches))\n",
|
|
" print(\n",
|
|
" f\"CUDA GPU '{gpu_name}' has capability {arch}, but this PyTorch build \"\n",
|
|
" f\"supports only: {supported}. Using CPU.\"\n",
|
|
" )\n",
|
|
" return torch.device(\"cpu\")\n",
|
|
"\n",
|
|
" return torch.device(\"cuda\")\n",
|
|
"\n",
|
|
"\n",
|
|
class UNetDownBlock(nn.Module):
    """U-Net encoder block.

    Stride-2 4x4 convolution (halves H and W), optional BatchNorm, LeakyReLU(0.2)
    and optional channel dropout, all held in one ``nn.Sequential`` named ``model``.
    """

    def __init__(self, in_channels: int, out_channels: int, normalize: bool = True, dropout: float = 0.0):
        super().__init__()
        # The convolution has no bias term (also when normalize=False).
        downsample = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
        norm = [nn.BatchNorm2d(out_channels)] if normalize else []
        drop = [nn.Dropout2d(dropout)] if dropout > 0 else []
        # A single Sequential keeps state_dict keys as model.<index>.*
        self.model = nn.Sequential(downsample, *norm, nn.LeakyReLU(0.2, inplace=True), *drop)

    def forward(self, x):
        return self.model(x)
|
|
"\n",
|
|
"\n",
|
|
class UNetUpBlock(nn.Module):
    """U-Net decoder block.

    2x bilinear upsample + 3x3 conv, BatchNorm, ReLU and optional dropout, followed by
    channel-wise concatenation with the encoder skip tensor. The output therefore has
    ``out_channels`` + (channels of ``skip_input``) channels.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
        super().__init__()
        # Upsample-then-convolve rather than a transposed convolution.
        upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.upconv = nn.Sequential(upsample, conv)
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(dropout) if dropout > 0 else None

    def forward(self, x, skip_input):
        out = self.upconv(x)
        if out.shape != skip_input.shape:
            # Zero-pad (or crop, when negative) to the skip tensor's H and W, split evenly per side.
            pad_h = skip_input.size(2) - out.size(2)
            pad_w = skip_input.size(3) - out.size(3)
            top = pad_h // 2
            left = pad_w // 2
            out = F.pad(out, [left, pad_w - left, top, pad_h - top])
        out = self.relu(self.norm(out))
        if self.dropout:
            out = self.dropout(out)
        return torch.cat([out, skip_input], dim=1)
|
|
"\n",
|
|
"\n",
|
|
class GeneratorUNet(nn.Module):
    """U-Net generator mapping a Google tile to a Yandex-style tile.

    Seven stride-2 encoder blocks plus a stride-2 bottleneck shrink H and W by 2**8
    (1x1 at the bottleneck for a 256x256 input). Seven decoder blocks mirror the
    encoder with skip connections, and a final upsample + conv + Tanh yields values
    in [-1, 1].
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3):
        super().__init__()

        # Encoder (comment: output channels x spatial size for a 256x256 input)
        self.down1 = UNetDownBlock(in_channels, 64, normalize=False)  # 64 x 128
        self.down2 = UNetDownBlock(64, 128)  # 128 x 64
        self.down3 = UNetDownBlock(128, 256)  # 256 x 32
        self.down4 = UNetDownBlock(256, 512)  # 512 x 16
        self.down5 = UNetDownBlock(512, 512)  # 512 x 8
        self.down6 = UNetDownBlock(512, 512)  # 512 x 4
        self.down7 = UNetDownBlock(512, 512)  # 512 x 2

        # Bottleneck: 512 x 1, no normalization
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )

        # Decoder: each block emits `out` channels, then concatenates its skip,
        # so the following block receives out + skip channels.
        self.up1 = UNetUpBlock(512, 512, dropout=0.5)  # + d7 -> 1024
        self.up2 = UNetUpBlock(1024, 512, dropout=0.5)  # + d6 -> 1024
        self.up3 = UNetUpBlock(1024, 512, dropout=0.5)  # + d5 -> 1024
        self.up4 = UNetUpBlock(1024, 512)  # + d4 -> 1024
        self.up5 = UNetUpBlock(1024, 256)  # + d3 -> 512
        self.up6 = UNetUpBlock(512, 128)  # + d2 -> 256
        self.up7 = UNetUpBlock(256, 64)  # + d1 -> 128

        # Output head: back to full resolution, squash to [-1, 1]
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(128, out_channels, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        encoders = (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6, self.down7)
        decoders = (self.up1, self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)

        skips = []
        h = x
        for block in encoders:
            h = block(h)
            skips.append(h)

        h = self.bottleneck(h)

        # Deepest skip pairs with the first decoder block (up1 <-> d7, ..., up7 <-> d1).
        for block, skip in zip(decoders, reversed(skips)):
            h = block(h, skip)

        return self.final(h)
|
|
"\n",
|
|
"\n",
|
|
class DiscriminatorPatchGAN(nn.Module):
    """Conditional PatchGAN.

    The source and target images are concatenated channel-wise and mapped to a
    (B, 1, h, w) map of unnormalized scores, one per receptive-field patch.
    """

    def __init__(self, in_channels: int = 6):
        super().__init__()
        # First stage has no normalization.
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three normalized stride-2 stages.
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Stride-1 head producing a single score channel.
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        pair = torch.cat((img_A, img_B), dim=1)
        return self.model(pair)
|
|
"\n",
|
|
"\n",
|
|
class GANLoss(nn.Module):
    """Adversarial loss wrapper.

    Modes:
        "vanilla": BCE-with-logits against constant real/fake labels.
        "lsgan":   MSE against constant real/fake labels (least-squares GAN).
        "wgangp":  Wasserstein score (-mean for real, +mean for fake); no gradient
                   penalty is computed here.
    Any other mode raises ValueError.
    """

    def __init__(self, gan_mode: str = "vanilla", target_real: float = 1.0, target_fake: float = 0.0):
        super().__init__()
        self.gan_mode = gan_mode
        # Registered as buffers so the labels follow the module across .to(device).
        self.register_buffer("real_label", torch.tensor(target_real))
        self.register_buffer("fake_label", torch.tensor(target_fake))

        criteria = {"vanilla": nn.BCEWithLogitsLoss, "lsgan": nn.MSELoss, "wgangp": None}
        if gan_mode not in criteria:
            raise ValueError(f"Unknown GAN mode: {gan_mode}")
        criterion_cls = criteria[gan_mode]
        self.loss_fn = None if criterion_cls is None else criterion_cls()

    def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
        if self.gan_mode == "wgangp":
            score = prediction.mean()
            return -score if target_is_real else score
        if self.gan_mode in ("vanilla", "lsgan"):
            label = self.real_label if target_is_real else self.fake_label
            return self.loss_fn(prediction, label.expand_as(prediction))
        return None  # unreachable unless gan_mode was changed after construction
|
|
"\n",
|
|
"\n",
|
|
class SSIMLoss(nn.Module):
    """Structural-similarity loss ``mean(1 - SSIM)`` for image tensors normalized to [-1, 1].

    Inputs are rescaled to [0, 1]. Local statistics use a uniform ``window_size`` x
    ``window_size`` box window with stride 1 (same spatial size for odd window sizes).
    ``c1`` / ``c2`` are the usual stabilizers (K1*L)**2 and (K2*L)**2 with L = 1.
    """

    def __init__(self, window_size: int = 11, c1: float = 0.01 ** 2, c2: float = 0.03 ** 2):
        super().__init__()
        self.window_size = window_size
        self.c1 = c1
        self.c2 = c2

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the scalar loss: 0 for identical inputs, up to 2 for fully anti-correlated ones."""
        pred = (pred + 1.0) * 0.5
        target = (target + 1.0) * 0.5
        padding = self.window_size // 2

        def local_mean(x: torch.Tensor) -> torch.Tensor:
            # count_include_pad=False: border windows average only real pixels. The default
            # (True) counts the zero padding as image content and biases border mean/variance.
            return F.avg_pool2d(x, self.window_size, stride=1, padding=padding, count_include_pad=False)

        mu_pred = local_mean(pred)
        mu_target = local_mean(target)
        mu_pred_sq = mu_pred.pow(2)
        mu_target_sq = mu_target.pow(2)
        mu_pred_target = mu_pred * mu_target

        # Local variances and covariance via E[xy] - E[x]E[y]
        sigma_pred = local_mean(pred * pred) - mu_pred_sq
        sigma_target = local_mean(target * target) - mu_target_sq
        sigma_pred_target = local_mean(pred * target) - mu_pred_target

        ssim_map = (
            (2 * mu_pred_target + self.c1) * (2 * sigma_pred_target + self.c2)
        ) / (
            (mu_pred_sq + mu_target_sq + self.c1) * (sigma_pred + sigma_target + self.c2)
        )
        # SSIM lies in [-1, 1]. Clamping to [0, 1] would zero the gradient in
        # anti-correlated regions, where the generator needs the most signal.
        return (1.0 - ssim_map.clamp(-1.0, 1.0)).mean()
|
|
"\n",
|
|
"\n",
|
|
class SobelEdgeLoss(nn.Module):
    """L1 distance between Sobel gradient-magnitude maps of two [-1, 1]-normalized RGB images."""

    def __init__(self):
        super().__init__()
        sobel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
            dtype=torch.float32,
        ).view(1, 1, 3, 3)
        # The vertical Sobel kernel is the transpose of the horizontal one.
        sobel_y = sobel_x.transpose(2, 3).contiguous()
        self.register_buffer("kernel_x", sobel_x)
        self.register_buffer("kernel_y", sobel_y)

    @staticmethod
    def _to_gray(x: torch.Tensor) -> torch.Tensor:
        rescaled = (x + 1.0) * 0.5  # back to [0, 1]
        # BT.601 luma weights; the result has a single channel.
        luma = rescaled.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
        return (rescaled * luma).sum(dim=1, keepdim=True)

    def _edges(self, x: torch.Tensor) -> torch.Tensor:
        gray = self._to_gray(x)
        gx, gy = (F.conv2d(gray, kernel, padding=1) for kernel in (self.kernel_x, self.kernel_y))
        # The small epsilon keeps sqrt differentiable where the gradient magnitude is zero.
        return (gx.pow(2) + gy.pow(2) + 1e-6).sqrt()

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        pred_edges = self._edges(pred)
        target_edges = self._edges(target)
        return F.l1_loss(pred_edges, target_edges)
|
|
"\n",
|
|
"\n",
|
|
class ImageGAN(nn.Module):
    """Pix2pix-style conditional GAN translating Google map tiles into Yandex-style tiles.

    Bundles the U-Net generator, the PatchGAN discriminator and the loss terms.
    Generator objective = lambda_GAN * adversarial + lambda_L1 * L1
    + lambda_SSIM * SSIM loss + lambda_edge * Sobel edge loss. At construction the
    module is moved to the device chosen by ``get_compatible_device``, which is kept
    in ``self.device``.
    """

    def __init__(
        self,
        input_channels: int = 3,
        output_channels: int = 3,
        gan_mode: str = "lsgan",
        lambda_L1: float = 150.0,
        lambda_GAN: float = 0.5,
        lambda_SSIM: float = 25.0,
        lambda_edge: float = 20.0,
        use_cuda: bool = True,
    ):
        super().__init__()
        # Networks. The discriminator sees source and target concatenated channel-wise.
        self.generator = GeneratorUNet(input_channels, output_channels)
        self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)

        # Loss terms
        self.gan_loss = GANLoss(gan_mode)
        self.l1_loss = nn.L1Loss()
        self.ssim_loss = SSIMLoss()
        self.edge_loss = SobelEdgeLoss()

        # Loss weights
        self.lambda_L1, self.lambda_GAN = lambda_L1, lambda_GAN
        self.lambda_SSIM, self.lambda_edge = lambda_SSIM, lambda_edge

        self.device = get_compatible_device(prefer_cuda=use_cuda)
        self.to(self.device)

    def forward(self, google_image):
        """Translate a Google image batch into Yandex-style images."""
        return self.generator(google_image)

    def generator_step(self, google_img, real_yandex_img):
        """Generator objective for one batch against the paired real Yandex image.

        Returns ``(total, gan, l1, ssim, edge)``; each term is already scaled by its lambda.
        """
        fake_yandex = self.generator(google_img)
        # Adversarial term: the generator wants (google, fake) scored as real.
        verdict = self.discriminator(google_img, fake_yandex)
        gan_term = self.gan_loss(verdict, True) * self.lambda_GAN
        l1_term = self.l1_loss(fake_yandex, real_yandex_img) * self.lambda_L1
        ssim_term = self.ssim_loss(fake_yandex, real_yandex_img) * self.lambda_SSIM
        edge_term = self.edge_loss(fake_yandex, real_yandex_img) * self.lambda_edge
        total = gan_term + l1_term + ssim_term + edge_term
        return total, gan_term, l1_term, ssim_term, edge_term

    def discriminator_step(self, google_img, real_yandex_img, fake_yandex_img):
        """Discriminator objective: real pairs -> real, generated pairs -> fake.

        Returns ``(mean of both terms, real term, fake term)``.
        """
        real_term = self.gan_loss(self.discriminator(google_img, real_yandex_img), True)
        # detach() stops discriminator gradients from flowing into the generator.
        fake_term = self.gan_loss(self.discriminator(google_img, fake_yandex_img.detach()), False)
        return (real_term + fake_term) * 0.5, real_term, fake_term
|
|
"\n",
|
|
"\n",
|
|
def create_gan(
    input_channels: int = 3,
    output_channels: int = 3,
    gan_mode: str = "lsgan",
    lambda_L1: float = 150.0,
    lambda_GAN: float = 0.5,
    lambda_SSIM: float = 25.0,
    lambda_edge: float = 20.0,
    use_cuda: bool = True,
) -> ImageGAN:
    """Factory for :class:`ImageGAN`; every argument is forwarded unchanged by keyword."""
    options = dict(
        input_channels=input_channels,
        output_channels=output_channels,
        gan_mode=gan_mode,
        lambda_L1=lambda_L1,
        lambda_GAN=lambda_GAN,
        lambda_SSIM=lambda_SSIM,
        lambda_edge=lambda_edge,
        use_cuda=use_cuda,
    )
    return ImageGAN(**options)
|
|
"\n",
|
|
"\n",
|
|
def initialize_weights(model: nn.Module, init_gain: float = 0.02) -> nn.Module:
    """Apply DCGAN/pix2pix-style initialization to ``model`` in place.

    * ``nn.Conv2d``: weight ~ N(0, init_gain); bias (when present) set to 0.
    * ``nn.BatchNorm2d``: weight ~ N(1, init_gain); bias set to 0. Skipped for
      ``affine=False`` layers, which have no weight/bias.

    Args:
        model: Module whose submodules are (re)initialized.
        init_gain: Standard deviation of the normal initializers (default 0.02).

    Returns:
        The same ``model``, so callers may write ``model = initialize_weights(model)``.
    """
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight, 0.0, init_gain)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            if module.weight is not None:
                nn.init.normal_(module.weight, 1.0, init_gain)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    return model
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
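"source": "## Training\n\n`GANTrainer` trains discriminator and generator alternately.\n\n**Training step:**\n1. Generate `fake_yandex = G(google_img)`\n2. Train discriminator on real pair `(google_img, yandex_img)` and fake pair `(google_img, fake_yandex)` (every `discriminator_update_interval` batches)\n3. Train generator against discriminator and paired Yandex target\n\n**Checkpoint saving:**\n- `best.pth`: lowest validation reconstruction loss (weighted L1 + SSIM + edge)\n- `epoch_N.pth`: every `save_interval` epochs\n- `final.pth`: at the end of training (including after early stopping)\n"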
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\"\"\"Trainer for GAN model.\"\"\"\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
class GANTrainer:
    """Alternating discriminator / generator trainer.

    ``model`` must provide ``generator``, ``discriminator``, ``device``,
    ``generator_step(google, yandex) -> (total, gan, l1, ssim, edge)`` and
    ``discriminator_step(google, yandex, fake) -> (total, real, fake)``.

    Config keys read (defaults): learning_rate (2e-4), discriminator_learning_rate
    (learning_rate * discriminator_lr_factor, factor 0.5), beta1 (0.5), beta2 (0.999),
    discriminator_update_interval (1), grad_clip (0/None = disabled), save_interval (5),
    early_stopping_patience (0 = disabled), output_dir ("runs/gan").
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Dict[str, Any],
    ):
        # Epoch averages divide by the batch count, so reject empty loaders up front
        # with a clear message instead of a ZeroDivisionError after the first epoch.
        if len(train_loader) == 0:
            raise ValueError("train_loader is empty: no training batches (check data_dir / train_split)")
        if len(val_loader) == 0:
            raise ValueError("val_loader is empty: no validation batches (lower train_split or add data)")

        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = model.device

        # Optimizers
        lr = config.get("learning_rate", 2e-4)
        lr_d = config.get("discriminator_learning_rate", lr * config.get("discriminator_lr_factor", 0.5))
        beta1 = config.get("beta1", 0.5)
        beta2 = config.get("beta2", 0.999)
        self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=(beta1, beta2))
        self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))
        # Max gradient norm for both networks; 0 or None disables clipping.
        self.grad_clip = float(config.get("grad_clip") or 0.0)

        # Training state / per-epoch histories
        self.current_epoch = 0
        self.best_val_loss = float("inf")
        self.g_losses = []
        self.d_losses = []
        self.l1_losses = []
        self.ssim_losses = []
        self.edge_losses = []
        self.val_g_losses = []
        self.val_d_losses = []
        self.val_l1_losses = []
        self.val_ssim_losses = []
        self.val_edge_losses = []
        self.val_reconstruction_losses = []

        # Output dir
        self.output_dir = Path(config.get("output_dir", "runs/gan"))
        self.output_dir.mkdir(parents=True, exist_ok=True)
        (self.output_dir / "checkpoints").mkdir(exist_ok=True)

        # Save config alongside the run for reproducibility
        with open(self.output_dir / "config.json", "w") as f:
            json.dump(config, f, indent=2)

    def _clip_gradients(self, module: torch.nn.Module) -> None:
        """Clip ``module``'s total gradient norm to ``self.grad_clip`` when it is positive."""
        if self.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(module.parameters(), self.grad_clip)

    def train_epoch(self) -> Tuple[float, float]:
        """Train for one epoch; return (mean generator loss, mean discriminator loss).

        The discriminator mean is taken only over batches where D was actually updated
        (every ``discriminator_update_interval`` batches), so skipped steps do not dilute it.
        """
        self.model.train()
        total_g = total_d = 0.0
        total_l1 = total_ssim = total_edge = 0.0
        d_updates = 0
        num_batches = len(self.train_loader)
        d_update_interval = max(1, self.config.get("discriminator_update_interval", 1))

        pbar = tqdm(enumerate(self.train_loader), total=num_batches, desc=f"Epoch {self.current_epoch + 1}")
        for batch_idx, batch in pbar:
            google_img = batch["google_img"].to(self.device)
            yandex_img = batch["yandex_img"].to(self.device)

            # Train D on (real, generated) pairs; the generator forward needs no graph here.
            if batch_idx % d_update_interval == 0:
                self.opt_D.zero_grad()
                with torch.no_grad():
                    fake_img = self.model.generator(google_img)
                d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]
                d_loss.backward()
                self._clip_gradients(self.model.discriminator)
                self.opt_D.step()
                total_d += d_loss.item()
                d_updates += 1
            else:
                d_loss = google_img.new_tensor(0.0)  # display placeholder only

            # Train G
            self.opt_G.zero_grad()
            g_loss, gan_loss, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)
            g_loss.backward()
            self._clip_gradients(self.model.generator)
            self.opt_G.step()

            total_g += g_loss.item()
            total_l1 += l1_loss.item()
            total_ssim += ssim_loss.item()
            total_edge += edge_loss.item()
            pbar.set_postfix({
                "g_loss": g_loss.item(),
                "d_loss": d_loss.item(),
                "l1": l1_loss.item(),
                "ssim": ssim_loss.item(),
                "edge": edge_loss.item(),
            })

        avg_g = total_g / num_batches
        avg_d = total_d / d_updates if d_updates else 0.0
        avg_l1 = total_l1 / num_batches
        avg_ssim = total_ssim / num_batches
        avg_edge = total_edge / num_batches
        self.g_losses.append(avg_g)
        self.d_losses.append(avg_d)
        self.l1_losses.append(avg_l1)
        self.ssim_losses.append(avg_ssim)
        self.edge_losses.append(avg_edge)
        return avg_g, avg_d

    @torch.no_grad()
    def validate(self) -> Tuple[float, float]:
        """Evaluate on the validation loader; return (mean generator loss, mean discriminator loss).

        Also records the reconstruction score (L1 + SSIM + edge terms, as returned by
        ``generator_step``) used for best-checkpoint selection and early stopping.
        """
        self.model.eval()
        total_g = total_d = 0.0
        total_l1 = total_ssim = total_edge = 0.0
        num_batches = len(self.val_loader)

        for batch in tqdm(self.val_loader, desc="Val"):
            google_img = batch["google_img"].to(self.device)
            yandex_img = batch["yandex_img"].to(self.device)
            fake_img = self.model.generator(google_img)
            g_loss, _, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)
            d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]
            total_g += g_loss.item()
            total_d += d_loss.item()
            total_l1 += l1_loss.item()
            total_ssim += ssim_loss.item()
            total_edge += edge_loss.item()

        avg_g = total_g / num_batches
        avg_d = total_d / num_batches
        avg_l1 = total_l1 / num_batches
        avg_ssim = total_ssim / num_batches
        avg_edge = total_edge / num_batches
        avg_reconstruction = avg_l1 + avg_ssim + avg_edge
        self.val_g_losses.append(avg_g)
        self.val_d_losses.append(avg_d)
        self.val_l1_losses.append(avg_l1)
        self.val_ssim_losses.append(avg_ssim)
        self.val_edge_losses.append(avg_edge)
        self.val_reconstruction_losses.append(avg_reconstruction)
        return avg_g, avg_d

    def train(self, num_epochs: int):
        """Run up to ``num_epochs`` epochs with checkpointing and optional early stopping.

        Saves ``best`` whenever the validation reconstruction loss improves,
        ``epoch_N`` every ``save_interval`` epochs, and ``final`` at the end.
        """
        print(f"Training for {num_epochs} epochs...")

        for epoch in range(num_epochs):
            self.current_epoch = epoch

            # Train & validate
            train_g, train_d = self.train_epoch()
            val_g, val_d = self.validate()

            val_reconstruction = self.val_reconstruction_losses[-1]
            if val_reconstruction < self.best_val_loss:
                self.best_val_loss = val_reconstruction
                self.save_checkpoint("best")

            # Periodic checkpoint
            if (epoch + 1) % self.config.get("save_interval", 5) == 0:
                self.save_checkpoint(f"epoch_{epoch + 1}")

            print(
                f"Epoch {epoch + 1}: "
                f"train_g={train_g:.4f}, train_d={train_d:.4f}, "
                f"train_l1={self.l1_losses[-1]:.4f}, train_ssim={self.ssim_losses[-1]:.4f}, "
                f"train_edge={self.edge_losses[-1]:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}, "
                f"val_l1={self.val_l1_losses[-1]:.4f}, val_ssim={self.val_ssim_losses[-1]:.4f}, "
                f"val_edge={self.val_edge_losses[-1]:.4f}, val_rec={val_reconstruction:.4f}"
            )

            # Early stopping: no improvement over the best seen before the last `patience` epochs
            patience = self.config.get("early_stopping_patience", 0)
            if patience > 0 and len(self.val_reconstruction_losses) > patience:
                recent = self.val_reconstruction_losses[-patience:]
                previous_best = min(self.val_reconstruction_losses[:-patience])
                if all(loss >= previous_best for loss in recent):
                    print(f"Early stopping at epoch {epoch + 1}")
                    break

        # Save final
        self.save_checkpoint("final")
        print(f"Training finished. Best val loss: {self.best_val_loss:.4f}")

    def save_checkpoint(self, name: str):
        """Save networks, optimizer states and the best validation score to ``checkpoints/<name>.pth``."""
        path = self.output_dir / "checkpoints" / f"{name}.pth"
        torch.save({
            "epoch": self.current_epoch,
            "generator": self.model.generator.state_dict(),
            "discriminator": self.model.discriminator.state_dict(),
            "opt_G": self.opt_G.state_dict(),
            "opt_D": self.opt_D.state_dict(),
            "best_val_loss": self.best_val_loss,
        }, path)
|
|
"\n",
|
|
"\n",
|
|
"def create_trainer(\n",
"    model: torch.nn.Module,\n",
"    train_loader: DataLoader,\n",
"    val_loader: DataLoader,\n",
"    config: Dict[str, Any],\n",
") -> GANTrainer:\n",
"    \"\"\"Construct a GANTrainer for the given model, data loaders and config.\n",
"\n",
"    Args:\n",
"        model: GAN model to train.\n",
"        train_loader: Loader for the training split.\n",
"        val_loader: Loader for the validation split.\n",
"        config: Training configuration dictionary.\n",
"\n",
"    Returns:\n",
"        A new ``GANTrainer`` instance.\n",
"    \"\"\"\n",
"    trainer = GANTrainer(model, train_loader, val_loader, config)\n",
"    return trainer\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": "## Analysis\n\nVisualization helpers for generated samples and collected training metrics.\n\nThe training history plot contains:\n1. Generator loss\n2. Discriminator loss\n3. L1 loss against the paired Yandex target\n4. SSIM structure loss\n5. Sobel edge loss\n6. Best-checkpoint reconstruction score (validation only)\n\nEach row of the sample grid contains:\n1. Google input\n2. Generated Yandex\n3. Paired Yandex target\n"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"def denormalize_image(tensor):\n",
"    \"\"\"Map a normalized tensor back to the displayable [0, 1] range.\n",
"\n",
"    Inverts a mean=0.5 / std=0.5 normalization (``x * 0.5 + 0.5``) on the\n",
"    detached CPU tensor and clamps the result to [0, 1].\n",
"    \"\"\"\n",
"    cpu_tensor = tensor.detach().cpu()\n",
"    return cpu_tensor.mul(0.5).add(0.5).clamp(min=0, max=1)\n",
|
|
"\n",
|
|
"\n",
|
|
"@torch.no_grad()\n",
"def visualize_generation(model, data_loader, output_dir, device=None, num_samples=4, show=True):\n",
"    \"\"\"Plot Google input / generated Yandex / Yandex target rows and save the figure.\n",
"\n",
"    Only the first batch of ``data_loader`` is used. The figure is written to\n",
"    ``<output_dir>/generation_samples.png``.\n",
"\n",
"    Args:\n",
"        model: GAN wrapper exposing ``generator`` (and ``device`` when ``device`` is None).\n",
"        data_loader: Loader whose batches are dicts with ``\"google_img\"`` and ``\"yandex_img\"``.\n",
"        output_dir: Directory for the PNG; created if missing.\n",
"        device: Device for generation; defaults to ``model.device``.\n",
"        num_samples: Maximum number of rows; capped at the size of the first batch.\n",
"        show: Whether to display the figure before it is closed.\n",
"\n",
"    Returns:\n",
"        Path of the saved PNG.\n",
"\n",
"    Raises:\n",
"        ValueError: If no sample can be shown (``num_samples < 1`` or an empty batch).\n",
"    \"\"\"\n",
"    device = device or model.device\n",
"    # Remember the caller's mode so visualization does not leave the model in eval mode.\n",
"    was_training = getattr(model, \"training\", None)\n",
"    model.eval()\n",
"    try:\n",
"        batch = next(iter(data_loader))\n",
"        # The first batch can hold fewer images than requested (small val split).\n",
"        num_samples = min(num_samples, batch[\"google_img\"].size(0))\n",
"        if num_samples < 1:\n",
"            raise ValueError(\"num_samples must be >= 1 and the first batch must be non-empty\")\n",
"        google_img = batch[\"google_img\"][:num_samples].to(device)\n",
"        yandex_img = batch[\"yandex_img\"][:num_samples].to(device)\n",
"        fake_yandex = model.generator(google_img)\n",
"    finally:\n",
"        if was_training is not None:\n",
"            model.train(was_training)\n",
"\n",
"    output_dir = Path(output_dir)\n",
"    output_dir.mkdir(parents=True, exist_ok=True)\n",
"    output_path = output_dir / \"generation_samples.png\"\n",
"\n",
"    # squeeze=False always returns a 2-D axes grid, including the single-row case.\n",
"    fig, axes = plt.subplots(num_samples, 3, figsize=(9, 3 * num_samples), squeeze=False)\n",
"    titles = [\"Google input\", \"Generated Yandex\", \"Yandex target\"]\n",
"    for row in range(num_samples):\n",
"        images = [google_img[row], fake_yandex[row], yandex_img[row]]\n",
"        for col, image in enumerate(images):\n",
"            axes[row, col].imshow(denormalize_image(image).permute(1, 2, 0))\n",
"            axes[row, col].set_title(titles[col])\n",
"            axes[row, col].axis(\"off\")\n",
"\n",
"    fig.tight_layout()\n",
"    fig.savefig(output_path, dpi=150)\n",
"    if show:\n",
"        plt.show()\n",
"    plt.close(fig)\n",
"    return output_path\n",
|
|
"\n",
|
|
"\n",
|
|
"def plot_training_history(trainer, output_dir, show=True):\n",
"    \"\"\"Draw per-epoch train/val metrics on a 2x3 grid and save ``training_history.png``.\n",
"\n",
"    Args:\n",
"        trainer: Object holding the per-epoch metric lists filled during training.\n",
"        output_dir: Directory for the PNG; created if missing.\n",
"        show: Whether to display the figure before it is closed.\n",
"\n",
"    Returns:\n",
"        Path of the saved PNG.\n",
"    \"\"\"\n",
"    target_dir = Path(output_dir)\n",
"    target_dir.mkdir(parents=True, exist_ok=True)\n",
"    history_path = target_dir / \"training_history.png\"\n",
"\n",
"    epoch_axis = range(1, len(trainer.g_losses) + 1)\n",
"    # Panel table: (title, [(per-epoch values, legend label), ...]).\n",
"    panels = [\n",
"        (\"Generator loss\", [(trainer.g_losses, \"train G\"), (trainer.val_g_losses, \"val G\")]),\n",
"        (\"Discriminator loss\", [(trainer.d_losses, \"train D\"), (trainer.val_d_losses, \"val D\")]),\n",
"        (\"Paired Yandex L1 loss\", [(trainer.l1_losses, \"train L1\"), (trainer.val_l1_losses, \"val L1\")]),\n",
"        (\"SSIM structure loss\", [(trainer.ssim_losses, \"train SSIM\"), (trainer.val_ssim_losses, \"val SSIM\")]),\n",
"        (\"Sobel edge loss\", [(trainer.edge_losses, \"train edge\"), (trainer.val_edge_losses, \"val edge\")]),\n",
"        (\"Best-checkpoint score\", [(trainer.val_reconstruction_losses, \"val reconstruction\")]),\n",
"    ]\n",
"\n",
"    fig, axes = plt.subplots(2, 3, figsize=(15, 8))\n",
"    for ax, (title, series) in zip(axes.ravel(), panels):\n",
"        for values, label in series:\n",
"            ax.plot(epoch_axis, values, label=label)\n",
"        ax.set_title(title)\n",
"        ax.set_xlabel(\"Epoch\")\n",
"        ax.legend()\n",
"        ax.grid(True, alpha=0.3)\n",
"\n",
"    fig.tight_layout()\n",
"    fig.savefig(history_path, dpi=150)\n",
"    if show:\n",
"        plt.show()\n",
"    plt.close(fig)\n",
"    return history_path\n",
|
|
"\n",
|
|
"\n",
|
|
"def analyze_training(trainer):\n",
"    \"\"\"Return the trainer's best score and per-epoch metric histories as a dict.\n",
"\n",
"    Values are the trainer's own objects (lists are not copied).\n",
"    \"\"\"\n",
"    metric_names = (\n",
"        \"best_val_loss\",\n",
"        \"g_losses\",\n",
"        \"d_losses\",\n",
"        \"l1_losses\",\n",
"        \"ssim_losses\",\n",
"        \"edge_losses\",\n",
"        \"val_g_losses\",\n",
"        \"val_d_losses\",\n",
"        \"val_l1_losses\",\n",
"        \"val_ssim_losses\",\n",
"        \"val_edge_losses\",\n",
"        \"val_reconstruction_losses\",\n",
"    )\n",
"    return {name: getattr(trainer, name) for name in metric_names}\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": "## Main Pipeline\n\nExecutes the full GAN workflow:\n1. Create config\n2. Build paired data loaders\n3. Initialize Google -> Yandex GAN\n4. Train with validation\n5. Save checkpoints in `runs/checkpoints/`\n6. Show loss plots and generated sample grid\n\nThis block is intentionally top-level, not wrapped in `main()`, so notebook\nvariables such as `model`, `trainer`, `train_loader`, `val_loader`, and\n`config` remain available for debugging and interactive experiments.\n"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\"\"\"Executable GAN training pipeline.\n",
"\n",
"The code is intentionally top-level, mirroring the SiaN notebook style:\n",
"when this file is included in the generated notebook, variables remain\n",
"available for debugging and interactive experiments.\n",
"\"\"\"\n",
"\n",
"config = create_config()\n",
"device = get_compatible_device(prefer_cuda=config[\"prefer_cuda\"])\n",
"print(f\"Using device: {device}\")\n",
"\n",
"train_loader, val_loader = create_data_loaders(\n",
"    root_dir=config[\"data_dir\"],\n",
"    batch_size=config[\"batch_size\"],\n",
"    train_split=config[\"train_split\"],\n",
"    image_size=tuple(config[\"image_size\"]),\n",
"    num_workers=config[\"num_workers\"],\n",
")\n",
"print(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
"\n",
"model = create_gan(\n",
"    gan_mode=config[\"gan_mode\"],\n",
"    lambda_GAN=config[\"lambda_GAN\"],\n",
"    lambda_L1=config[\"lambda_L1\"],\n",
"    lambda_SSIM=config[\"lambda_SSIM\"],\n",
"    lambda_edge=config[\"lambda_edge\"],\n",
"    use_cuda=(device.type == \"cuda\"),\n",
")\n",
"\n",
"generator_params = sum(p.numel() for p in model.generator.parameters())\n",
"discriminator_params = sum(p.numel() for p in model.discriminator.parameters())\n",
"print(f\"Model created: generator={generator_params:,}, discriminator={discriminator_params:,}\")\n",
"\n",
"trainer = create_trainer(model, train_loader, val_loader, config)\n",
"trainer.train(config[\"epochs\"])\n",
"\n",
"# After training, `model` holds the last-epoch weights, but the exported\n",
"# artefact is best.pth. Reload the best generator (when one was saved) so the\n",
"# sample grid shows the checkpoint that is actually shipped.\n",
"best_checkpoint_path = trainer.output_dir / \"checkpoints\" / \"best.pth\"\n",
"if best_checkpoint_path.exists():\n",
"    best_checkpoint = torch.load(best_checkpoint_path, map_location=device)\n",
"    model.generator.load_state_dict(best_checkpoint[\"generator\"])\n",
"    print(f\"Loaded best generator weights from {best_checkpoint_path}\")\n",
"\n",
"training_analysis = analyze_training(trainer)\n",
"images_dir = Path(config[\"output_dir\"]) / \"images\"\n",
"history_plot_path = plot_training_history(trainer, images_dir)\n",
"generation_samples_path = visualize_generation(\n",
"    model=model,\n",
"    data_loader=val_loader,\n",
"    output_dir=images_dir,\n",
"    device=device,\n",
"    num_samples=config[\"num_visual_samples\"],\n",
")\n",
"\n",
"print(f\"Training history plot: {history_plot_path}\")\n",
"print(f\"Generation samples: {generation_samples_path}\")\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import zipfile\n",
"\n",
"# Build the archive with Python instead of the external `zip` binary, which is not\n",
"# available on every platform (e.g. Windows), and take the file locations from the\n",
"# run itself instead of assuming a `runs/` folder in the working directory.\n",
"artefact_paths = [\n",
"    trainer.output_dir / \"checkpoints\" / \"best.pth\",\n",
"    Path(history_plot_path),\n",
"    Path(generation_samples_path),\n",
"]\n",
"with zipfile.ZipFile(\"artefacts.zip\", \"w\", compression=zipfile.ZIP_DEFLATED) as archive:\n",
"    for artefact_path in artefact_paths:\n",
"        if not artefact_path.exists():\n",
"            print(f\"Skipping missing artefact: {artefact_path}\")\n",
"            continue\n",
"        # Keep the previous archive layout: runs/<subdir>/<file>.\n",
"        archive.write(artefact_path, arcname=f\"runs/{artefact_path.parent.name}/{artefact_path.name}\")\n",
"print(\"Saved artefacts.zip\")\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"name": "python",
|
|
"version": "3.11.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
} |