Files
autopilot/models/GAN/gan.py
2026-02-20 16:52:02 +03:00

394 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNetDownBlock(nn.Module):
    """Downsampling stage of the U-Net generator.

    The stage is a bias-free 4x4 convolution with stride 2 and padding 1,
    optionally followed by batch normalisation, then ``LeakyReLU(0.2)`` and,
    when ``dropout`` is positive, channel-wise dropout.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        normalize: Insert ``BatchNorm2d`` after the convolution when true.
        dropout: Probability for ``Dropout2d``; no dropout layer is added
            unless the value is greater than zero.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        normalize: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        stages = [conv]
        if normalize:
            stages += [nn.BatchNorm2d(out_channels)]
        stages += [nn.LeakyReLU(0.2, inplace=True)]
        if dropout > 0:
            stages += [nn.Dropout2d(dropout)]
        # Keep the attribute name ``model`` so state-dict keys stay ``model.N``.
        self.model = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the input through the stacked layers and return the result."""
        return self.model(x)
class UNetUpBlock(nn.Module):
    """Upsampling stage of the U-Net generator.

    The stage is a bias-free 4x4 transposed convolution with stride 2 and
    padding 1, followed by batch normalisation, ReLU and, when ``dropout`` is
    positive, channel-wise dropout. The result is concatenated with a skip
    tensor along the channel dimension.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the transposed
            convolution (before concatenation with the skip tensor).
        dropout: Probability for ``Dropout2d``; no dropout layer is added
            unless the value is greater than zero.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if dropout > 0:
            stages.append(nn.Dropout2d(dropout))
        self.model = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` and concatenate it with ``skip_input``.

        When the upsampled tensor and the skip tensor differ in shape, the
        upsampled tensor (not the skip tensor) is padded with ``F.pad`` so its
        height and width match the skip tensor. The difference is split
        between both sides, with the odd pixel going to the right/bottom.

        Returns:
            The upsampled tensor followed by ``skip_input`` along dim 1.
        """
        upsampled = self.model(x)
        if upsampled.shape != skip_input.shape:
            gap_h = skip_input.size(2) - upsampled.size(2)
            gap_w = skip_input.size(3) - upsampled.size(3)
            left = gap_w // 2
            top = gap_h // 2
            # F.pad takes the padding for the last dimension first (width).
            upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        return torch.cat([upsampled, skip_input], dim=1)
class GeneratorUNet(nn.Module):
    """U-Net generator for the Yandex → Google image translation.

    Seven stride-2 encoder stages and a stride-2 bottleneck are mirrored by
    seven decoder stages with skip connections and a final transposed
    convolution with ``Tanh``.

    For a 700x700 input the encoder feature maps are 350, 175, 87, 43, 21,
    10 and 5 pixels wide, the bottleneck is 2x2, and the output is 700x700
    again. Each decoder stage pads its upsampled tensor up to the size of
    the matching skip tensor.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3):
        super().__init__()

        # Encoder: the first stage has no batch norm.
        self.down1 = UNetDownBlock(in_channels, 64, normalize=False)
        self.down2 = UNetDownBlock(64, 128)
        self.down3 = UNetDownBlock(128, 256)
        self.down4 = UNetDownBlock(256, 512)
        self.down5 = UNetDownBlock(512, 512)
        self.down6 = UNetDownBlock(512, 512)
        self.down7 = UNetDownBlock(512, 512)

        # Innermost layer: convolution + ReLU, no normalisation.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )

        # Decoder: from up2 on, the input width is doubled by the skip concat.
        self.up1 = UNetUpBlock(512, 512, dropout=0.5)
        self.up2 = UNetUpBlock(1024, 512, dropout=0.5)
        self.up3 = UNetUpBlock(1024, 512, dropout=0.5)
        self.up4 = UNetUpBlock(1024, 512)
        self.up5 = UNetUpBlock(1024, 256)
        self.up6 = UNetUpBlock(512, 128)
        self.up7 = UNetUpBlock(256, 64)

        # Output head: 64 decoder channels + 64 skip channels -> image.
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Translate an input image batch and return the generated batch."""
        encoder = (
            self.down1,
            self.down2,
            self.down3,
            self.down4,
            self.down5,
            self.down6,
            self.down7,
        )
        decoder = (
            self.up1,
            self.up2,
            self.up3,
            self.up4,
            self.up5,
            self.up6,
            self.up7,
        )

        # Encode, remembering every stage output for the skip connections.
        skips = []
        features = x
        for stage in encoder:
            features = stage(features)
            skips.append(features)

        features = self.bottleneck(features)

        # Decode, pairing each stage with the encoder output of equal depth
        # (deepest first).
        for stage, skip in zip(decoder, reversed(skips)):
            features = stage(features, skip)

        return self.final(features)
class DiscriminatorPatchGAN(nn.Module):
    """PatchGAN discriminator scoring a pair of images patch by patch.

    The two images are concatenated along the channel dimension and passed
    through four stride-2 convolution blocks and one stride-1 convolution
    that yields a single-channel score map. For 700x700 inputs the feature
    maps are 350, 175, 87 and 43 pixels wide and the score map is 42x42.

    By default the network returns raw logits. ``BCEWithLogitsLoss``
    applies the sigmoid itself, so a trailing ``nn.Sigmoid`` here would
    squash the scores twice. Pass ``use_sigmoid=True`` only when
    probabilities are explicitly required.

    Args:
        in_channels: Total channels of the concatenated image pair
            (3 + 3 for two RGB images).
        use_sigmoid: Append ``nn.Sigmoid`` to the score map. The layer has
            no parameters, so state-dict keys are identical either way.
    """

    def __init__(self, in_channels: int = 6, use_sigmoid: bool = False):
        super().__init__()

        def discriminator_block(
            in_filters: int, out_filters: int, normalization: bool = True
        ):
            """Return the layers of one stride-2 discriminator block."""
            layers = [
                nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)
            ]
            if normalization:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = [
            *discriminator_block(in_channels, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
        ]
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, img_A: torch.Tensor, img_B: torch.Tensor) -> torch.Tensor:
        """Score an image pair.

        Args:
            img_A: First image batch of the pair.
            img_B: Second image batch of the pair, same batch and spatial
                size as ``img_A``.

        Returns:
            A one-channel map of per-patch scores: logits by default,
            probabilities when the model was built with ``use_sigmoid=True``.
        """
        # Stack both images along the channel dimension.
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)
class GANLoss(nn.Module):
    """Adversarial loss with a boolean "is real" target.

    Supported modes:

    * ``"vanilla"`` uses ``BCEWithLogitsLoss``, so predictions must be raw
      logits rather than probabilities.
    * ``"lsgan"`` uses ``MSELoss`` against the real/fake label.
    * ``"wgangp"`` uses the negated mean prediction for real targets and
      the mean prediction for fake targets.

    Args:
        gan_mode: One of ``"vanilla"``, ``"lsgan"`` or ``"wgangp"``.
        target_real_label: Label value used for real targets.
        target_fake_label: Label value used for fake targets.

    Raises:
        NotImplementedError: If ``gan_mode`` is not a supported mode.
    """

    def __init__(
        self,
        gan_mode: str = "vanilla",
        target_real_label: float = 1.0,
        target_fake_label: float = 0.0,
    ):
        super().__init__()
        # Buffers, so the labels follow the module across devices.
        self.register_buffer("real_label", torch.tensor(target_real_label))
        self.register_buffer("fake_label", torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == "lsgan":
            self.loss = nn.MSELoss()
        elif gan_mode == "wgangp":
            self.loss = None
        else:
            raise NotImplementedError(f"GAN mode {gan_mode} not implemented")

    def get_target_tensor(
        self, prediction: torch.Tensor, target_is_real: bool
    ) -> torch.Tensor:
        """Return the label broadcast to the shape of ``prediction``.

        The label is first moved to the prediction's device and dtype, so
        the loss still works when this module was left on another device
        than the discriminator output.
        """
        label = self.real_label if target_is_real else self.fake_label
        label = label.to(device=prediction.device, dtype=prediction.dtype)
        return label.expand_as(prediction)

    def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
        """Compute the adversarial loss for ``prediction``.

        Implemented as ``forward`` rather than ``__call__`` so the regular
        ``nn.Module`` call machinery (including hooks) stays active;
        ``criterion(prediction, flag)`` keeps working as before.

        Args:
            prediction: Discriminator output.
            target_is_real: Whether the prediction should be scored against
                the real label (``True``) or the fake label (``False``).

        Returns:
            A scalar loss tensor.
        """
        if self.gan_mode == "wgangp":
            mean_score = prediction.mean()
            return -mean_score if target_is_real else mean_score
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target_tensor)
class ImageGAN(nn.Module):
    """Conditional GAN translating Yandex images into Google-style images.

    Bundles a ``GeneratorUNet``, a ``DiscriminatorPatchGAN``, the
    adversarial criterion (``GANLoss``) and an L1 reconstruction loss, and
    exposes the loss computation for one generator step and one
    discriminator step. No optimizer is created and no backward pass is run
    here; both are left to the caller.

    NOTE(review): ``gan_mode="wgangp"`` only selects the loss form in
    ``GANLoss``; this class computes no gradient-penalty term.

    Args:
        input_channels: Channels of the source (Yandex) image.
        output_channels: Channels of the target (Google) image.
        gan_mode: Adversarial loss mode passed to ``GANLoss``.
        lambda_L1: Weight applied to the L1 reconstruction loss.
        use_cuda: Place the model on CUDA when it is available.
    """

    def __init__(
        self,
        input_channels: int = 3,
        output_channels: int = 3,
        gan_mode: str = "vanilla",
        lambda_L1: float = 100.0,
        use_cuda: bool = True,
    ):
        super().__init__()
        self.generator = GeneratorUNet(input_channels, output_channels)
        # The discriminator sees source and target stacked along channels.
        self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
        self.gan_loss = GANLoss(gan_mode)
        self.l1_loss = nn.L1Loss()
        self.lambda_L1 = lambda_L1
        self.device = torch.device(
            "cuda" if use_cuda and torch.cuda.is_available() else "cpu"
        )
        self.to(self.device)

    def forward(self, yandex_image: torch.Tensor) -> torch.Tensor:
        """Generate a Google-style image from a Yandex image."""
        return self.generator(yandex_image)

    def generator_step(
        self, yandex_image: torch.Tensor, real_google_image: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the generator losses for one batch.

        Returns:
            total_loss: Sum of the adversarial and weighted L1 losses.
            gan_loss: Adversarial loss (the generator tries to get its
                output scored as real).
            l1_loss: L1 distance to the real image, already multiplied by
                ``lambda_L1``.
        """
        fake_google_image = self.generator(yandex_image)
        fake_pred = self.discriminator(yandex_image, fake_google_image)
        # The generator is rewarded when the discriminator says "real".
        gan_loss = self.gan_loss(fake_pred, True)
        l1_loss = self.l1_loss(fake_google_image, real_google_image) * self.lambda_L1
        total_loss = gan_loss + l1_loss
        return total_loss, gan_loss, l1_loss

    def discriminator_step(
        self,
        yandex_image: torch.Tensor,
        real_google_image: torch.Tensor,
        fake_google_image: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the discriminator losses for one batch.

        Returns:
            total_loss: Mean of the real and fake losses.
            real_loss: Loss on the (source, real target) pair.
            fake_loss: Loss on the (source, generated target) pair.
        """
        real_pred = self.discriminator(yandex_image, real_google_image)
        real_loss = self.gan_loss(real_pred, True)
        # Detach so this loss sends no gradients into the generator.
        fake_pred = self.discriminator(yandex_image, fake_google_image.detach())
        fake_loss = self.gan_loss(fake_pred, False)
        total_loss = (real_loss + fake_loss) * 0.5
        return total_loss, real_loss, fake_loss

    def to(self, *args, **kwargs):
        """Move or cast the whole model and return ``self``.

        Delegates to ``nn.Module.to`` so every submodule moves together,
        including the label buffers held by ``GANLoss``. ``self.device`` is
        then refreshed from the parameters so ``load_checkpoint`` maps
        tensors to the right place.
        """
        super().to(*args, **kwargs)
        first_param = next(self.parameters(), None)
        if first_param is not None:
            self.device = first_param.device
        return self

    def train_mode(self):
        """Put the generator and the discriminator into training mode."""
        self.generator.train()
        self.discriminator.train()

    def eval_mode(self):
        """Put the generator and the discriminator into evaluation mode."""
        self.generator.eval()
        self.discriminator.eval()

    def save_checkpoint(self, path: str):
        """Save generator and discriminator weights to ``path``.

        Optimizer state is included only if the caller has stored it as an
        ``optimizer_state_dict`` attribute on the respective network;
        otherwise ``None`` is written for that entry.
        """
        checkpoint = {
            "generator_state_dict": self.generator.state_dict(),
            "discriminator_state_dict": self.discriminator.state_dict(),
            "generator_optimizer_state_dict": getattr(
                self.generator, "optimizer_state_dict", None
            ),
            "discriminator_optimizer_state_dict": getattr(
                self.discriminator, "optimizer_state_dict", None
            ),
        }
        torch.save(checkpoint, path)

    def load_checkpoint(self, path: str):
        """Load a checkpoint written by ``save_checkpoint``.

        The optimizer entries are optional: a missing key is treated the
        same as a stored ``None``.
        """
        # SECURITY: torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(path, map_location=self.device)
        self.generator.load_state_dict(checkpoint["generator_state_dict"])
        self.discriminator.load_state_dict(checkpoint["discriminator_state_dict"])

        generator_opt_state = checkpoint.get("generator_optimizer_state_dict")
        if generator_opt_state is not None:
            self.generator.optimizer_state_dict = generator_opt_state

        discriminator_opt_state = checkpoint.get("discriminator_optimizer_state_dict")
        if discriminator_opt_state is not None:
            self.discriminator.optimizer_state_dict = discriminator_opt_state
def create_image_gan(
    input_channels: int = 3,
    output_channels: int = 3,
    gan_mode: str = "vanilla",
    lambda_L1: float = 100.0,
    use_cuda: bool = True,
) -> ImageGAN:
    """Build an ``ImageGAN`` for image-to-image translation.

    Args:
        input_channels: Number of input channels (usually 3 for RGB).
        output_channels: Number of output channels (usually 3 for RGB).
        gan_mode: GAN loss mode ('vanilla', 'lsgan', 'wgangp').
        lambda_L1: Weight of the L1 loss.
        use_cuda: Use CUDA when it is available.

    Returns:
        The constructed ``ImageGAN`` model.
    """
    settings = {
        "input_channels": input_channels,
        "output_channels": output_channels,
        "gan_mode": gan_mode,
        "lambda_L1": lambda_L1,
        "use_cuda": use_cuda,
    }
    return ImageGAN(**settings)
# Helper functions for weight initialization
def weights_init_normal(m):
    """Initialise one module in place; meant for ``nn.Module.apply``.

    Modules whose class name contains ``"Conv"`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``"BatchNorm"`` get
    weights drawn from N(1, 0.02) and a zero bias. Any other module is left
    untouched, as are layers whose weight or bias is ``None`` (for example
    ``BatchNorm2d(affine=False)``).

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        # The bias lives directly on the layer; there is no ``batch_norm``
        # sub-attribute on BatchNorm modules.
        if bias is not None:
            nn.init.constant_(bias.data, 0.0)
def initialize_gan_weights(generator: nn.Module, discriminator: nn.Module):
    """Apply ``weights_init_normal`` to every submodule of both networks.

    The generator is initialised first, then the discriminator.
    """
    for network in (generator, discriminator):
        network.apply(weights_init_normal)