Files
autopilot/models/GAN/notebook.gen.ipynb
2026-05-30 14:49:40 +03:00

1064 lines
46 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"from pathlib import Path\n",
"from typing import Any, Dict, List, Tuple\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from PIL import Image\n",
"from torch.utils.data import DataLoader, Dataset, Subset\n",
"from torchvision import transforms\n",
"from tqdm import tqdm\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "# Configuration\n\nGlobal settings for the Google -> Yandex GAN:\n- Dataset path and image size\n- Optimizer and training hyperparameters\n- Device preference with safe CUDA compatibility checks\n- GAN, L1, SSIM and edge reconstruction weights\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"Configuration for GAN training.\"\"\"\n",
"\n",
"\n",
"def create_config():\n",
" \"\"\"Create default configuration dictionary.\"\"\"\n",
" return {\n",
" # Optimizer params\n",
" \"learning_rate\": 2e-4,\n",
" \"discriminator_lr_factor\": 0.5,\n",
" \"beta1\": 0.5,\n",
" \"beta2\": 0.999,\n",
" # Training params\n",
" \"batch_size\": 32,\n",
" \"epochs\": 100,\n",
" \"prefer_cuda\": True,\n",
" # GAN params\n",
" \"gan_mode\": \"lsgan\",\n",
" \"lambda_GAN\": 0.5,\n",
" \"lambda_L1\": 150.0,\n",
" \"lambda_SSIM\": 25.0,\n",
" \"lambda_edge\": 20.0,\n",
" \"discriminator_update_interval\": 1,\n",
" # Regularization\n",
" \"grad_clip\": 1.0,\n",
" # Early stopping\n",
" \"early_stopping_patience\": 25,\n",
" # Output\n",
" \"output_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\models\\GAN\\runs\",\n",
" # Logging\n",
" \"log_interval\": 10,\n",
" \"save_interval\": 5,\n",
" \"num_visual_samples\": 4,\n",
" # Data\n",
" \"data_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\",\n",
" \"image_size\": [256, 256],\n",
" \"train_split\": 0.8,\n",
" \"num_workers\": 0,\n",
" }\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Dataset\n\nGoogle/Yandex paired image loader.\n\n**Direction:**\n- `google_img` is the generator input\n- `yandex_img` is the target image from the same pair\n\n**Returns:**\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"Data loader for Google-to-Yandex image translation.\"\"\"\n",
"\n",
"\n",
"\n",
"\n",
"class YaGoDataset(Dataset):\n",
" \"\"\"Dataset loading paired Google and Yandex map images.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" root_dir: str,\n",
" image_size: Tuple[int, int] = (256, 256),\n",
" augment: bool = False,\n",
" ):\n",
" \"\"\"\n",
" Args:\n",
" root_dir: Directory with images named {idx:04d}_google.png and {idx:04d}_yandex.png\n",
" image_size: Target image size (height, width)\n",
" augment: Whether to apply augmentation (not implemented for simplicity)\n",
" \"\"\"\n",
" self.root_dir = root_dir\n",
" self.image_size = image_size\n",
" self.augment = augment\n",
"\n",
" # Discover image pairs\n",
" self.pairs = self._find_pairs()\n",
"\n",
" # Transform to tensor + normalization\n",
" self.transform = transforms.Compose(\n",
" [\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),\n",
" ]\n",
" )\n",
"\n",
" def _find_pairs(self) -> List[Dict]:\n",
" \"\"\"Find all matching Google-Yandex image pairs.\"\"\"\n",
" pairs = []\n",
" google_files = [f for f in os.listdir(self.root_dir) if f.endswith(\"_google.png\")]\n",
"\n",
" for google_file in sorted(google_files):\n",
" idx_str = google_file.split(\"_\")[0]\n",
" try:\n",
" idx = int(idx_str)\n",
" except ValueError:\n",
" continue\n",
"\n",
" yandex_file = f\"{idx:04d}_yandex.png\"\n",
" yandex_path = os.path.join(self.root_dir, yandex_file)\n",
"\n",
" if os.path.exists(yandex_path):\n",
" pairs.append(\n",
" {\n",
" \"idx\": idx,\n",
" \"google_path\": os.path.join(self.root_dir, google_file),\n",
" \"yandex_path\": yandex_path,\n",
" }\n",
" )\n",
"\n",
" return pairs\n",
"\n",
" def __len__(self) -> int:\n",
" return len(self.pairs)\n",
"\n",
" def __getitem__(self, idx: int) -> dict:\n",
" pair = self.pairs[idx]\n",
"\n",
" # Load images\n",
" google_img = Image.open(pair[\"google_path\"]).convert(\"RGB\")\n",
" yandex_img = Image.open(pair[\"yandex_path\"]).convert(\"RGB\")\n",
"\n",
" # Resize\n",
" google_img = google_img.resize((self.image_size[1], self.image_size[0]))\n",
" yandex_img = yandex_img.resize((self.image_size[1], self.image_size[0]))\n",
"\n",
" # Apply transforms\n",
" google_tensor = self.transform(google_img)\n",
" yandex_tensor = self.transform(yandex_img)\n",
"\n",
" return {\n",
" \"google_img\": google_tensor,\n",
" \"yandex_img\": yandex_tensor,\n",
" \"idx\": torch.tensor(pair[\"idx\"], dtype=torch.long),\n",
" }\n",
"\n",
"\n",
"def create_data_loaders(\n",
" root_dir: str,\n",
" batch_size: int = 32,\n",
" train_split: float = 0.8,\n",
" num_workers: int = 0,\n",
" image_size: Tuple[int, int] = (256, 256),\n",
") -> Tuple[DataLoader, DataLoader]:\n",
" \"\"\"\n",
" Create train and validation data loaders.\n",
"\n",
" Args:\n",
" root_dir: Directory with image pairs\n",
" batch_size: Batch size\n",
" train_split: Fraction for training (0.0-1.0)\n",
" num_workers: DataLoader workers\n",
" image_size: Target image size\n",
"\n",
" Returns:\n",
" (train_loader, val_loader)\n",
" \"\"\"\n",
" # Full dataset\n",
" dataset = YaGoDataset(root_dir=root_dir, image_size=image_size)\n",
"\n",
" # Split\n",
" dataset_size = len(dataset)\n",
" train_size = int(train_split * dataset_size)\n",
" indices = torch.randperm(dataset_size).tolist()\n",
" train_indices = indices[:train_size]\n",
" val_indices = indices[train_size:]\n",
"\n",
" # Subsets\n",
"\n",
" train_dataset = Subset(dataset, train_indices)\n",
" val_dataset = Subset(dataset, val_indices)\n",
"\n",
" # DataLoaders\n",
" train_loader = DataLoader(\n",
" train_dataset,\n",
" batch_size=batch_size,\n",
" shuffle=True,\n",
" num_workers=num_workers,\n",
" pin_memory=True,\n",
" )\n",
"\n",
" val_loader = DataLoader(\n",
" val_dataset,\n",
" batch_size=batch_size,\n",
" shuffle=False,\n",
" num_workers=num_workers,\n",
" pin_memory=True,\n",
" )\n",
"\n",
" return train_loader, val_loader\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"def get_dataset_info():\n",
" config = create_config()\n",
" dataset = YaGoDataset(\n",
" root_dir=config[\"data_dir\"],\n",
" image_size=tuple(config[\"image_size\"]),\n",
" )\n",
" sample = dataset[0] if len(dataset) else {}\n",
" return {\n",
" \"size\": len(dataset),\n",
" \"sample_keys\": list(sample.keys()),\n",
" \"google_shape\": tuple(sample[\"google_img\"].shape) if sample else None,\n",
" \"yandex_shape\": tuple(sample[\"yandex_img\"].shape) if sample else None,\n",
" }\n",
"\n",
"\n",
"def smoke_test_dataloader(batch_size=4):\n",
" config = create_config()\n",
" train_loader, val_loader = create_data_loaders(\n",
" root_dir=config[\"data_dir\"],\n",
" batch_size=batch_size,\n",
" train_split=config[\"train_split\"],\n",
" num_workers=config[\"num_workers\"],\n",
" image_size=tuple(config[\"image_size\"]),\n",
" )\n",
" batch = next(iter(train_loader))\n",
" return {\n",
" \"train_size\": len(train_loader.dataset),\n",
" \"val_size\": len(val_loader.dataset),\n",
" \"google_batch_shape\": tuple(batch[\"google_img\"].shape),\n",
" \"yandex_batch_shape\": tuple(batch[\"yandex_img\"].shape),\n",
" }\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Model\n\nPix2pix-style GAN for Google -> Yandex map translation.\n\n**Generator:**\n- `GeneratorUNet`\n- Input: Google image `(B, 3, H, W)`\n- Output: generated Yandex image `(B, 3, H, W)`\n\n**Discriminator:**\n- `DiscriminatorPatchGAN`\n- Input pair: `(google_img, yandex_img)`\n- Learns to distinguish real pairs from `(google_img, fake_yandex)`\n\n**Generator loss:**\n- adversarial loss\n- `lambda_L1 * L1(fake_yandex, yandex_img)`\n- `lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)`\n- `lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)`\n\nThe generator uses bilinear upsampling followed by convolution to avoid\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"GAN model for image translation Google -> Yandex.\"\"\"\n",
"\n",
"\n",
"\n",
"def get_compatible_device(prefer_cuda: bool = True, verbose: bool = True) -> torch.device:\n",
" \"\"\"Return CUDA only when the current PyTorch build supports the GPU arch.\"\"\"\n",
" if not prefer_cuda or not torch.cuda.is_available():\n",
" return torch.device(\"cpu\")\n",
"\n",
" try:\n",
" major, minor = torch.cuda.get_device_capability()\n",
" arch = f\"sm_{major}{minor}\"\n",
" supported_arches = set(torch.cuda.get_arch_list())\n",
" gpu_name = torch.cuda.get_device_name()\n",
" except Exception as exc:\n",
" if verbose:\n",
" print(f\"CUDA is visible but cannot be inspected ({exc}); using CPU.\")\n",
" return torch.device(\"cpu\")\n",
"\n",
" if supported_arches and arch not in supported_arches:\n",
" if verbose:\n",
" supported = \", \".join(sorted(supported_arches))\n",
" print(\n",
" f\"CUDA GPU '{gpu_name}' has capability {arch}, but this PyTorch build \"\n",
" f\"supports only: {supported}. Using CPU.\"\n",
" )\n",
" return torch.device(\"cpu\")\n",
"\n",
" return torch.device(\"cuda\")\n",
"\n",
"\n",
"class UNetDownBlock(nn.Module):\n",
" \"\"\"Downsampling block for U-Net.\"\"\"\n",
"\n",
" def __init__(self, in_channels: int, out_channels: int, normalize: bool = True, dropout: float = 0.0):\n",
" super().__init__()\n",
" layers = [\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)\n",
" ]\n",
" if normalize:\n",
" layers.append(nn.BatchNorm2d(out_channels))\n",
" layers.append(nn.LeakyReLU(0.2, inplace=True))\n",
" if dropout > 0:\n",
" layers.append(nn.Dropout2d(dropout))\n",
" self.model = nn.Sequential(*layers)\n",
"\n",
" def forward(self, x):\n",
" return self.model(x)\n",
"\n",
"\n",
"class UNetUpBlock(nn.Module):\n",
" \"\"\"Upsampling block for U-Net.\"\"\"\n",
"\n",
" def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):\n",
" super().__init__()\n",
" self.upconv = nn.Sequential(\n",
" nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=False),\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),\n",
" )\n",
" self.norm = nn.BatchNorm2d(out_channels)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" if dropout > 0:\n",
" self.dropout = nn.Dropout2d(dropout)\n",
" else:\n",
" self.dropout = None\n",
"\n",
" def forward(self, x, skip_input):\n",
" x = self.upconv(x)\n",
" # Pad if needed to match skip connection size\n",
" if x.shape != skip_input.shape:\n",
" diff_h = skip_input.size(2) - x.size(2)\n",
" diff_w = skip_input.size(3) - x.size(3)\n",
" x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])\n",
" x = self.norm(x)\n",
" x = self.relu(x)\n",
" if self.dropout:\n",
" x = self.dropout(x)\n",
" x = torch.cat([x, skip_input], dim=1)\n",
" return x\n",
"\n",
"\n",
"class GeneratorUNet(nn.Module):\n",
" \"\"\"U-Net generator for Google -> Yandex translation.\"\"\"\n",
"\n",
" def __init__(self, in_channels: int = 3, out_channels: int = 3):\n",
" super().__init__()\n",
"\n",
" # Downsampling\n",
" self.down1 = UNetDownBlock(in_channels, 64, normalize=False)\n",
" self.down2 = UNetDownBlock(64, 128)\n",
" self.down3 = UNetDownBlock(128, 256)\n",
" self.down4 = UNetDownBlock(256, 512)\n",
" self.down5 = UNetDownBlock(512, 512)\n",
" self.down6 = UNetDownBlock(512, 512)\n",
" self.down7 = UNetDownBlock(512, 512)\n",
"\n",
" # Bottleneck\n",
" self.bottleneck = nn.Sequential(\n",
" nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),\n",
" nn.ReLU(inplace=True),\n",
" )\n",
"\n",
" # Upsampling - input channels from previous layer, output before concat\n",
" self.up1 = UNetUpBlock(512, 512, dropout=0.5) # in: 512 (bottleneck) -> out: 512, concat with d7 (512) = 1024\n",
" self.up2 = UNetUpBlock(1024, 512, dropout=0.5) # in: 1024 -> out: 512, concat with d6 (512) = 1024\n",
" self.up3 = UNetUpBlock(1024, 512, dropout=0.5) # in: 1024 -> out: 512, concat with d5 (512) = 1024\n",
" self.up4 = UNetUpBlock(1024, 512) # in: 1024 -> out: 512, concat with d4 (512) = 1024\n",
" self.up5 = UNetUpBlock(1024, 256) # in: 1024 -> out: 256, concat with d3 (256) = 512\n",
" self.up6 = UNetUpBlock(512, 128) # in: 512 -> out: 128, concat with d2 (128) = 256\n",
" self.up7 = UNetUpBlock(256, 64) # in: 256 -> out: 64, concat with d1 (64) = 128\n",
"\n",
" # Final\n",
" self.final = nn.Sequential(\n",
" nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=False),\n",
" nn.Conv2d(128, out_channels, kernel_size=3, padding=1),\n",
" nn.Tanh(),\n",
" )\n",
"\n",
" def forward(self, x):\n",
" # Down\n",
" d1 = self.down1(x)\n",
" d2 = self.down2(d1)\n",
" d3 = self.down3(d2)\n",
" d4 = self.down4(d3)\n",
" d5 = self.down5(d4)\n",
" d6 = self.down6(d5)\n",
" d7 = self.down7(d6)\n",
"\n",
" # Bottleneck\n",
" u = self.bottleneck(d7)\n",
"\n",
" # Up with skip connections\n",
" u = self.up1(u, d7)\n",
" u = self.up2(u, d6)\n",
" u = self.up3(u, d5)\n",
" u = self.up4(u, d4)\n",
" u = self.up5(u, d3)\n",
" u = self.up6(u, d2)\n",
" u = self.up7(u, d1)\n",
"\n",
" return self.final(u)\n",
"\n",
"\n",
"class DiscriminatorPatchGAN(nn.Module):\n",
" \"\"\"PatchGAN discriminator for paired source/target images.\"\"\"\n",
"\n",
" def __init__(self, in_channels: int = 6):\n",
" super().__init__()\n",
" self.model = nn.Sequential(\n",
" nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n",
" nn.BatchNorm2d(128),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),\n",
" nn.BatchNorm2d(256),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),\n",
" nn.BatchNorm2d(512),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),\n",
" )\n",
"\n",
" def forward(self, img_A, img_B):\n",
" x = torch.cat([img_A, img_B], dim=1)\n",
" return self.model(x)\n",
"\n",
"\n",
"class GANLoss(nn.Module):\n",
" \"\"\"GAN loss supporting different GAN modes.\"\"\"\n",
"\n",
" def __init__(self, gan_mode: str = \"vanilla\", target_real: float = 1.0, target_fake: float = 0.0):\n",
" super().__init__()\n",
" self.gan_mode = gan_mode\n",
" self.register_buffer(\"real_label\", torch.tensor(target_real))\n",
" self.register_buffer(\"fake_label\", torch.tensor(target_fake))\n",
"\n",
" if gan_mode == \"vanilla\":\n",
" self.loss_fn = nn.BCEWithLogitsLoss()\n",
" elif gan_mode == \"lsgan\":\n",
" self.loss_fn = nn.MSELoss()\n",
" elif gan_mode == \"wgangp\":\n",
" self.loss_fn = None\n",
" else:\n",
" raise ValueError(f\"Unknown GAN mode: {gan_mode}\")\n",
"\n",
" def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:\n",
" if self.gan_mode in [\"vanilla\", \"lsgan\"]:\n",
" target = self.real_label if target_is_real else self.fake_label\n",
" target = target.expand_as(prediction)\n",
" return self.loss_fn(prediction, target)\n",
" elif self.gan_mode == \"wgangp\":\n",
" return -prediction.mean() if target_is_real else prediction.mean()\n",
"\n",
"\n",
"class SSIMLoss(nn.Module):\n",
" \"\"\"Local SSIM loss for normalized image tensors in [-1, 1].\"\"\"\n",
"\n",
" def __init__(self, window_size: int = 11, c1: float = 0.01 ** 2, c2: float = 0.03 ** 2):\n",
" super().__init__()\n",
" self.window_size = window_size\n",
" self.c1 = c1\n",
" self.c2 = c2\n",
"\n",
" def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
" pred = (pred + 1.0) * 0.5\n",
" target = (target + 1.0) * 0.5\n",
" padding = self.window_size // 2\n",
"\n",
" mu_pred = F.avg_pool2d(pred, self.window_size, stride=1, padding=padding)\n",
" mu_target = F.avg_pool2d(target, self.window_size, stride=1, padding=padding)\n",
" mu_pred_sq = mu_pred.pow(2)\n",
" mu_target_sq = mu_target.pow(2)\n",
" mu_pred_target = mu_pred * mu_target\n",
"\n",
" sigma_pred = F.avg_pool2d(pred * pred, self.window_size, stride=1, padding=padding) - mu_pred_sq\n",
" sigma_target = F.avg_pool2d(target * target, self.window_size, stride=1, padding=padding) - mu_target_sq\n",
" sigma_pred_target = F.avg_pool2d(pred * target, self.window_size, stride=1, padding=padding) - mu_pred_target\n",
"\n",
" ssim_map = (\n",
" (2 * mu_pred_target + self.c1) * (2 * sigma_pred_target + self.c2)\n",
" ) / (\n",
" (mu_pred_sq + mu_target_sq + self.c1) * (sigma_pred + sigma_target + self.c2)\n",
" )\n",
" return (1.0 - ssim_map.clamp(0, 1)).mean()\n",
"\n",
"\n",
"class SobelEdgeLoss(nn.Module):\n",
" \"\"\"L1 loss between Sobel edge maps, useful for stable keypoint structure.\"\"\"\n",
"\n",
" def __init__(self):\n",
" super().__init__()\n",
" kernel_x = torch.tensor(\n",
" [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],\n",
" dtype=torch.float32,\n",
" ).view(1, 1, 3, 3)\n",
" kernel_y = torch.tensor(\n",
" [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],\n",
" dtype=torch.float32,\n",
" ).view(1, 1, 3, 3)\n",
" self.register_buffer(\"kernel_x\", kernel_x)\n",
" self.register_buffer(\"kernel_y\", kernel_y)\n",
"\n",
" @staticmethod\n",
" def _to_gray(x: torch.Tensor) -> torch.Tensor:\n",
" x = (x + 1.0) * 0.5\n",
" weights = x.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)\n",
" return (x * weights).sum(dim=1, keepdim=True)\n",
"\n",
" def _edges(self, x: torch.Tensor) -> torch.Tensor:\n",
" gray = self._to_gray(x)\n",
" grad_x = F.conv2d(gray, self.kernel_x, padding=1)\n",
" grad_y = F.conv2d(gray, self.kernel_y, padding=1)\n",
" return torch.sqrt(grad_x.pow(2) + grad_y.pow(2) + 1e-6)\n",
"\n",
" def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
" return F.l1_loss(self._edges(pred), self._edges(target))\n",
"\n",
"\n",
"class ImageGAN(nn.Module):\n",
" \"\"\"Complete pix2pix-style GAN for Google -> Yandex image translation.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" input_channels: int = 3,\n",
" output_channels: int = 3,\n",
" gan_mode: str = \"lsgan\",\n",
" lambda_L1: float = 150.0,\n",
" lambda_GAN: float = 0.5,\n",
" lambda_SSIM: float = 25.0,\n",
" lambda_edge: float = 20.0,\n",
" use_cuda: bool = True,\n",
" ):\n",
" super().__init__()\n",
" self.generator = GeneratorUNet(input_channels, output_channels)\n",
" self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)\n",
" self.gan_loss = GANLoss(gan_mode)\n",
" self.l1_loss = nn.L1Loss()\n",
" self.ssim_loss = SSIMLoss()\n",
" self.edge_loss = SobelEdgeLoss()\n",
" self.lambda_L1 = lambda_L1\n",
" self.lambda_GAN = lambda_GAN\n",
" self.lambda_SSIM = lambda_SSIM\n",
" self.lambda_edge = lambda_edge\n",
"\n",
" self.device = get_compatible_device(prefer_cuda=use_cuda)\n",
" self.to(self.device)\n",
"\n",
" def forward(self, google_image):\n",
" \"\"\"Generate a Yandex-style image from a Google image.\"\"\"\n",
" return self.generator(google_image)\n",
"\n",
" def generator_step(self, google_img, real_yandex_img):\n",
" \"\"\"Compute generator losses against the paired original Yandex image.\"\"\"\n",
" fake_yandex = self.generator(google_img)\n",
" fake_pred = self.discriminator(google_img, fake_yandex)\n",
" gan_loss = self.gan_loss(fake_pred, True) * self.lambda_GAN\n",
" l1_loss = self.l1_loss(fake_yandex, real_yandex_img) * self.lambda_L1\n",
" ssim_loss = self.ssim_loss(fake_yandex, real_yandex_img) * self.lambda_SSIM\n",
" edge_loss = self.edge_loss(fake_yandex, real_yandex_img) * self.lambda_edge\n",
" total_loss = gan_loss + l1_loss + ssim_loss + edge_loss\n",
" return total_loss, gan_loss, l1_loss, ssim_loss, edge_loss\n",
"\n",
" def discriminator_step(self, google_img, real_yandex_img, fake_yandex_img):\n",
" \"\"\"Compute discriminator losses for real and generated Yandex targets.\"\"\"\n",
" real_pred = self.discriminator(google_img, real_yandex_img)\n",
" real_loss = self.gan_loss(real_pred, True)\n",
" fake_pred = self.discriminator(google_img, fake_yandex_img.detach())\n",
" fake_loss = self.gan_loss(fake_pred, False)\n",
" total_loss = (real_loss + fake_loss) * 0.5\n",
" return total_loss, real_loss, fake_loss\n",
"\n",
"\n",
"def create_gan(\n",
" input_channels: int = 3,\n",
" output_channels: int = 3,\n",
" gan_mode: str = \"lsgan\",\n",
" lambda_L1: float = 150.0,\n",
" lambda_GAN: float = 0.5,\n",
" lambda_SSIM: float = 25.0,\n",
" lambda_edge: float = 20.0,\n",
" use_cuda: bool = True,\n",
") -> ImageGAN:\n",
" \"\"\"Create a GAN model.\"\"\"\n",
" return ImageGAN(\n",
" input_channels=input_channels,\n",
" output_channels=output_channels,\n",
" gan_mode=gan_mode,\n",
" lambda_L1=lambda_L1,\n",
" lambda_GAN=lambda_GAN,\n",
" lambda_SSIM=lambda_SSIM,\n",
" lambda_edge=lambda_edge,\n",
" use_cuda=use_cuda,\n",
" )\n",
"\n",
"\n",
"def initialize_weights(model: nn.Module):\n",
" \"\"\"Initialize model weights.\"\"\"\n",
" for m in model.modules():\n",
" if isinstance(m, nn.Conv2d):\n",
" nn.init.normal_(m.weight.data, 0.0, 0.02)\n",
" elif isinstance(m, nn.BatchNorm2d):\n",
" nn.init.normal_(m.weight.data, 1.0, 0.02)\n",
" nn.init.constant_(m.bias.data, 0.0)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Training\n\n`GANTrainer` trains discriminator and generator alternately.\n\n**Training step:**\n1. Generate `fake_yandex = G(google_img)`\n2. Train discriminator on real pair `(google_img, yandex_img)` and fake pair `(google_img, fake_yandex)`\n3. Train generator against discriminator and paired Yandex target\n\n**Checkpoint saving:**\n- `best.pth`\n- `epoch_N.pth`\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"Trainer for GAN model.\"\"\"\n",
"\n",
"\n",
"\n",
"\n",
"class GANTrainer:\n",
" \"\"\"Simple GAN trainer.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" model: torch.nn.Module,\n",
" train_loader: DataLoader,\n",
" val_loader: DataLoader,\n",
" config: Dict[str, Any],\n",
" ):\n",
" self.model = model\n",
" self.train_loader = train_loader\n",
" self.val_loader = val_loader\n",
" self.config = config\n",
" self.device = model.device\n",
"\n",
" # Optimizers\n",
" lr = config.get(\"learning_rate\", 2e-4)\n",
" lr_d = config.get(\"discriminator_learning_rate\", lr * config.get(\"discriminator_lr_factor\", 0.5))\n",
" beta1 = config.get(\"beta1\", 0.5)\n",
" beta2 = config.get(\"beta2\", 0.999)\n",
" self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=(beta1, beta2))\n",
" self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))\n",
"\n",
" # Training state\n",
" self.current_epoch = 0\n",
" self.best_val_loss = float(\"inf\")\n",
" self.g_losses = []\n",
" self.d_losses = []\n",
" self.l1_losses = []\n",
" self.ssim_losses = []\n",
" self.edge_losses = []\n",
" self.val_g_losses = []\n",
" self.val_d_losses = []\n",
" self.val_l1_losses = []\n",
" self.val_ssim_losses = []\n",
" self.val_edge_losses = []\n",
" self.val_reconstruction_losses = []\n",
"\n",
" # Output dir\n",
" self.output_dir = Path(config.get(\"output_dir\", \"runs/gan\"))\n",
" self.output_dir.mkdir(parents=True, exist_ok=True)\n",
" (self.output_dir / \"checkpoints\").mkdir(exist_ok=True)\n",
"\n",
" # Save config\n",
" with open(self.output_dir / \"config.json\", \"w\") as f:\n",
" json.dump(config, f, indent=2)\n",
"\n",
" def train_epoch(self) -> Tuple[float, float]:\n",
" \"\"\"Train for one epoch.\"\"\"\n",
" self.model.train()\n",
" total_g = total_d = 0.0\n",
" total_l1 = total_ssim = total_edge = 0.0\n",
" num_batches = len(self.train_loader)\n",
" d_update_interval = max(1, self.config.get(\"discriminator_update_interval\", 1))\n",
"\n",
" pbar = tqdm(enumerate(self.train_loader), total=num_batches, desc=f\"Epoch {self.current_epoch + 1}\")\n",
" for batch_idx, batch in pbar:\n",
" google_img = batch[\"google_img\"].to(self.device)\n",
" yandex_img = batch[\"yandex_img\"].to(self.device)\n",
"\n",
" # Train D\n",
" if batch_idx % d_update_interval == 0:\n",
" self.opt_D.zero_grad()\n",
" with torch.no_grad():\n",
" fake_img = self.model.generator(google_img)\n",
" d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]\n",
" d_loss.backward()\n",
" self.opt_D.step()\n",
" else:\n",
" d_loss = google_img.new_tensor(0.0)\n",
"\n",
" # Train G\n",
" self.opt_G.zero_grad()\n",
" g_loss, gan_loss, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)\n",
" g_loss.backward()\n",
" self.opt_G.step()\n",
"\n",
" total_g += g_loss.item()\n",
" total_d += d_loss.item()\n",
" total_l1 += l1_loss.item()\n",
" total_ssim += ssim_loss.item()\n",
" total_edge += edge_loss.item()\n",
" pbar.set_postfix({\n",
" \"g_loss\": g_loss.item(),\n",
" \"d_loss\": d_loss.item(),\n",
" \"l1\": l1_loss.item(),\n",
" \"ssim\": ssim_loss.item(),\n",
" \"edge\": edge_loss.item(),\n",
" })\n",
"\n",
" avg_g = total_g / num_batches\n",
" avg_d = total_d / num_batches\n",
" avg_l1 = total_l1 / num_batches\n",
" avg_ssim = total_ssim / num_batches\n",
" avg_edge = total_edge / num_batches\n",
" self.g_losses.append(avg_g)\n",
" self.d_losses.append(avg_d)\n",
" self.l1_losses.append(avg_l1)\n",
" self.ssim_losses.append(avg_ssim)\n",
" self.edge_losses.append(avg_edge)\n",
" return avg_g, avg_d\n",
"\n",
" @torch.no_grad()\n",
" def validate(self) -> Tuple[float, float]:\n",
" \"\"\"Validate the model.\"\"\"\n",
" self.model.eval()\n",
" total_g = total_d = 0.0\n",
" total_l1 = total_ssim = total_edge = 0.0\n",
"\n",
" for batch in tqdm(self.val_loader, desc=\"Val\"):\n",
" google_img = batch[\"google_img\"].to(self.device)\n",
" yandex_img = batch[\"yandex_img\"].to(self.device)\n",
" fake_img = self.model.generator(google_img)\n",
" g_loss, _, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)\n",
" d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]\n",
" total_g += g_loss.item()\n",
" total_d += d_loss.item()\n",
" total_l1 += l1_loss.item()\n",
" total_ssim += ssim_loss.item()\n",
" total_edge += edge_loss.item()\n",
"\n",
" avg_g = total_g / len(self.val_loader)\n",
" avg_d = total_d / len(self.val_loader)\n",
" avg_l1 = total_l1 / len(self.val_loader)\n",
" avg_ssim = total_ssim / len(self.val_loader)\n",
" avg_edge = total_edge / len(self.val_loader)\n",
" avg_reconstruction = avg_l1 + avg_ssim + avg_edge\n",
" self.val_g_losses.append(avg_g)\n",
" self.val_d_losses.append(avg_d)\n",
" self.val_l1_losses.append(avg_l1)\n",
" self.val_ssim_losses.append(avg_ssim)\n",
" self.val_edge_losses.append(avg_edge)\n",
" self.val_reconstruction_losses.append(avg_reconstruction)\n",
" return avg_g, avg_d\n",
"\n",
" def train(self, num_epochs: int):\n",
" \"\"\"Train the model.\"\"\"\n",
" print(f\"Training for {num_epochs} epochs...\")\n",
"\n",
" for epoch in range(num_epochs):\n",
" self.current_epoch = epoch\n",
"\n",
" # Train & validate\n",
" train_g, train_d = self.train_epoch()\n",
" val_g, val_d = self.validate()\n",
"\n",
" val_reconstruction = self.val_reconstruction_losses[-1]\n",
" if val_reconstruction < self.best_val_loss:\n",
" self.best_val_loss = val_reconstruction\n",
" self.save_checkpoint(\"best\")\n",
"\n",
" # Periodic checkpoint\n",
" if (epoch + 1) % self.config.get(\"save_interval\", 5) == 0:\n",
" self.save_checkpoint(f\"epoch_{epoch + 1}\")\n",
"\n",
" print(\n",
" f\"Epoch {epoch + 1}: \"\n",
" f\"train_g={train_g:.4f}, train_d={train_d:.4f}, \"\n",
" f\"train_l1={self.l1_losses[-1]:.4f}, train_ssim={self.ssim_losses[-1]:.4f}, \"\n",
" f\"train_edge={self.edge_losses[-1]:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}, \"\n",
" f\"val_l1={self.val_l1_losses[-1]:.4f}, val_ssim={self.val_ssim_losses[-1]:.4f}, \"\n",
" f\"val_edge={self.val_edge_losses[-1]:.4f}, val_rec={val_reconstruction:.4f}\"\n",
" )\n",
"\n",
" # Early stopping\n",
" patience = self.config.get(\"early_stopping_patience\", 0)\n",
" if patience > 0 and len(self.val_reconstruction_losses) > patience:\n",
" recent = self.val_reconstruction_losses[-patience:]\n",
" previous_best = min(self.val_reconstruction_losses[:-patience])\n",
" if all(loss >= previous_best for loss in recent):\n",
" print(f\"Early stopping at epoch {epoch + 1}\")\n",
" break\n",
"\n",
" # Save final\n",
" self.save_checkpoint(\"final\")\n",
" print(f\"Training finished. Best val loss: {self.best_val_loss:.4f}\")\n",
"\n",
" def save_checkpoint(self, name: str):\n",
" \"\"\"Save model checkpoint.\"\"\"\n",
" path = self.output_dir / \"checkpoints\" / f\"{name}.pth\"\n",
" torch.save({\n",
" \"epoch\": self.current_epoch,\n",
" \"generator\": self.model.generator.state_dict(),\n",
" \"discriminator\": self.model.discriminator.state_dict(),\n",
" \"opt_G\": self.opt_G.state_dict(),\n",
" \"opt_D\": self.opt_D.state_dict(),\n",
" }, path)\n",
"\n",
"\n",
"def create_trainer(\n",
" model: torch.nn.Module,\n",
" train_loader: DataLoader,\n",
" val_loader: DataLoader,\n",
" config: Dict[str, Any],\n",
") -> GANTrainer:\n",
" \"\"\"Create a trainer instance.\"\"\"\n",
" return GANTrainer(model, train_loader, val_loader, config)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Analysis\n\nVisualization helpers for generated samples and collected training metrics.\n\nTraining history plot contains:\n1. Generator loss\n2. Discriminator loss\n3. L1 loss against the paired Yandex target\n4. SSIM structure loss\n5. Sobel edge loss\n6. Best-checkpoint reconstruction score\n\nThe sample grid contains:\n1. Google input\n2. Generated Yandex\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"\n",
"def denormalize_image(tensor):\n",
" return (tensor.detach().cpu() * 0.5 + 0.5).clamp(0, 1)\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def visualize_generation(model, data_loader, output_dir, device=None, num_samples=4, show=True):\n",
" device = device or model.device\n",
" model.eval()\n",
"\n",
" batch = next(iter(data_loader))\n",
" google_img = batch[\"google_img\"][:num_samples].to(device)\n",
" yandex_img = batch[\"yandex_img\"][:num_samples].to(device)\n",
" fake_yandex = model.generator(google_img)\n",
"\n",
" output_dir = Path(output_dir)\n",
" output_dir.mkdir(parents=True, exist_ok=True)\n",
" output_path = output_dir / \"generation_samples.png\"\n",
"\n",
" fig, axes = plt.subplots(num_samples, 3, figsize=(9, 3 * num_samples))\n",
" if num_samples == 1:\n",
" axes = axes.reshape(1, 3)\n",
"\n",
" titles = [\"Google input\", \"Generated Yandex\", \"Yandex target\"]\n",
" for row in range(num_samples):\n",
" images = [google_img[row], fake_yandex[row], yandex_img[row]]\n",
" for col, image in enumerate(images):\n",
" axes[row, col].imshow(denormalize_image(image).permute(1, 2, 0))\n",
" axes[row, col].set_title(titles[col])\n",
" axes[row, col].axis(\"off\")\n",
"\n",
" fig.tight_layout()\n",
" fig.savefig(output_path, dpi=150)\n",
" if show:\n",
" plt.show()\n",
" plt.close(fig)\n",
" return output_path\n",
"\n",
"\n",
"def plot_training_history(trainer, output_dir, show=True):\n",
"    \"\"\"Plot the trainer's loss histories on a 2x3 grid and save it as a PNG.\n",
"\n",
"    Panels: generator, discriminator, L1, SSIM and edge losses (train and\n",
"    validation curves each) plus the validation reconstruction loss, which is\n",
"    titled as the best-checkpoint score.\n",
"\n",
"    Args:\n",
"        trainer: Object exposing the loss-history sequences named in ``panels``.\n",
"        output_dir: Directory for ``training_history.png``; created if missing.\n",
"        show: If true, also display the figure via ``plt.show()``.\n",
"\n",
"    Returns:\n",
"        Path of the saved PNG.\n",
"    \"\"\"\n",
"    output_dir = Path(output_dir)\n",
"    output_dir.mkdir(parents=True, exist_ok=True)\n",
"    output_path = output_dir / \"training_history.png\"\n",
"\n",
"    # (panel title, ((trainer attribute, legend label), ...)) in drawing order.\n",
"    panels = (\n",
"        (\"Generator loss\", ((\"g_losses\", \"train G\"), (\"val_g_losses\", \"val G\"))),\n",
"        (\"Discriminator loss\", ((\"d_losses\", \"train D\"), (\"val_d_losses\", \"val D\"))),\n",
"        (\"Paired Yandex L1 loss\", ((\"l1_losses\", \"train L1\"), (\"val_l1_losses\", \"val L1\"))),\n",
"        (\"SSIM structure loss\", ((\"ssim_losses\", \"train SSIM\"), (\"val_ssim_losses\", \"val SSIM\"))),\n",
"        (\"Sobel edge loss\", ((\"edge_losses\", \"train edge\"), (\"val_edge_losses\", \"val edge\"))),\n",
"        (\"Best-checkpoint score\", ((\"val_reconstruction_losses\", \"val reconstruction\"),)),\n",
"    )\n",
"\n",
"    fig, axes = plt.subplots(2, 3, figsize=(15, 8))\n",
"    for ax, (title, curves) in zip(axes.ravel(), panels):\n",
"        for attribute, label in curves:\n",
"            values = getattr(trainer, attribute)\n",
"            # Every curve gets its own epoch axis, so histories of different\n",
"            # lengths still plot instead of failing on an x/y size mismatch.\n",
"            ax.plot(range(1, len(values) + 1), values, label=label)\n",
"        ax.set_title(title)\n",
"        ax.set_xlabel(\"Epoch\")\n",
"        ax.legend()\n",
"        ax.grid(True, alpha=0.3)\n",
"\n",
"    fig.tight_layout()\n",
"    fig.savefig(output_path, dpi=150)\n",
"    if show:\n",
"        plt.show()\n",
"    plt.close(fig)\n",
"    return output_path\n",
"\n",
"\n",
"def analyze_training(trainer):\n",
" return {\n",
" \"best_val_loss\": trainer.best_val_loss,\n",
" \"g_losses\": trainer.g_losses,\n",
" \"d_losses\": trainer.d_losses,\n",
" \"l1_losses\": trainer.l1_losses,\n",
" \"ssim_losses\": trainer.ssim_losses,\n",
" \"edge_losses\": trainer.edge_losses,\n",
" \"val_g_losses\": trainer.val_g_losses,\n",
" \"val_d_losses\": trainer.val_d_losses,\n",
" \"val_l1_losses\": trainer.val_l1_losses,\n",
" \"val_ssim_losses\": trainer.val_ssim_losses,\n",
" \"val_edge_losses\": trainer.val_edge_losses,\n",
" \"val_reconstruction_losses\": trainer.val_reconstruction_losses,\n",
" }\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Main Pipeline\n\nExecutes the full GAN workflow:\n1. Create config\n2. Build paired data loaders\n3. Initialize Google -> Yandex GAN\n4. Train with validation\n5. Save checkpoints in `runs/checkpoints/`\n6. Show loss plots and generated sample grid\n\nThis block is intentionally top-level, not wrapped in `main()`, so notebook\nvariables such as `model`, `trainer`, `train_loader`, `val_loader`, and\n`training_analysis` remain available for debugging and interactive experiments.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"Top-level GAN training run.\n",
"\n",
"Nothing here is wrapped in a function on purpose (same approach as the SiaN\n",
"notebook): once this file is included in the generated notebook, every name\n",
"assigned below stays in the kernel for debugging and interactive experiments.\n",
"\"\"\"\n",
"\n",
"# --- Configuration and device ---------------------------------------------\n",
"config = create_config()\n",
"device = get_compatible_device(prefer_cuda=config[\"prefer_cuda\"])\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# --- Data loaders ------------------------------------------------------------\n",
"train_loader, val_loader = create_data_loaders(\n",
"    root_dir=config[\"data_dir\"],\n",
"    image_size=tuple(config[\"image_size\"]),\n",
"    # These options are forwarded under the same name they have in the config.\n",
"    **{key: config[key] for key in (\"batch_size\", \"train_split\", \"num_workers\")},\n",
")\n",
"print(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
"\n",
"# --- Model -------------------------------------------------------------------\n",
"model = create_gan(\n",
"    use_cuda=(device.type == \"cuda\"),\n",
"    **{\n",
"        key: config[key]\n",
"        for key in (\"gan_mode\", \"lambda_GAN\", \"lambda_L1\", \"lambda_SSIM\", \"lambda_edge\")\n",
"    },\n",
")\n",
"\n",
"# Total element count over all parameters of each network.\n",
"generator_params, discriminator_params = (\n",
"    sum(weights.numel() for weights in network.parameters())\n",
"    for network in (model.generator, model.discriminator)\n",
")\n",
"print(f\"Model created: generator={generator_params:,}, discriminator={discriminator_params:,}\")\n",
"\n",
"# --- Training ----------------------------------------------------------------\n",
"trainer = create_trainer(model, train_loader, val_loader, config)\n",
"trainer.train(config[\"epochs\"])\n",
"\n",
"# --- Reporting ---------------------------------------------------------------\n",
"training_analysis = analyze_training(trainer)\n",
"images_dir = Path(config[\"output_dir\"]) / \"images\"\n",
"history_plot_path = plot_training_history(trainer, images_dir)\n",
"generation_samples_path = visualize_generation(\n",
"    model,\n",
"    val_loader,\n",
"    images_dir,\n",
"    device=device,\n",
"    num_samples=config[\"num_visual_samples\"],\n",
")\n",
"\n",
"print(f\"Training history plot: {history_plot_path}\")\n",
"print(f\"Generation samples: {generation_samples_path}\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Bundle the best checkpoint and the two result figures into artefacts.zip.\n",
"# The standard-library zipfile module replaces the external `zip` command,\n",
"# which is not installed on every platform.\n",
"import zipfile\n",
"\n",
"# Paths are relative to the kernel's working directory and are stored in the\n",
"# archive under these same names.\n",
"artefact_files = [\n",
"    \"runs/checkpoints/best.pth\",\n",
"    \"runs/images/training_history.png\",\n",
"    \"runs/images/generation_samples.png\",\n",
"]\n",
"with zipfile.ZipFile(\"artefacts.zip\", \"w\", compression=zipfile.ZIP_DEFLATED) as archive:\n",
"    for artefact_file in artefact_files:\n",
"        if Path(artefact_file).exists():\n",
"            archive.write(artefact_file)\n",
"        else:\n",
"            # Keep going so the remaining artefacts are still archived.\n",
"            print(f\"Skipping missing artefact: {artefact_file}\")\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}