from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class UNetDownBlock(nn.Module):
    """Downsampling block of the U-Net generator.

    Applies a stride-2 4x4 convolution (halving the spatial size), then an
    optional BatchNorm, a LeakyReLU(0.2) and an optional channel dropout.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        normalize: whether to add ``BatchNorm2d`` after the convolution.
        dropout: ``Dropout2d`` probability; no dropout layer when ``<= 0``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        normalize: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            )
        ]
        if normalize:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        if dropout > 0:
            layers.append(nn.Dropout2d(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on ``x`` and return the downsampled feature map."""
        return self.model(x)


class UNetUpBlock(nn.Module):
    """Upsampling block of the U-Net generator.

    Applies a stride-2 4x4 transposed convolution (doubling the spatial
    size), BatchNorm, ReLU and an optional channel dropout, then
    concatenates the result with a skip connection along the channel axis.

    Args:
        in_channels: number of input channels.
        out_channels: number of channels produced before concatenation.
        dropout: ``Dropout2d`` probability; no dropout layer when ``<= 0``.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if dropout > 0:
            layers.append(nn.Dropout2d(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` and concatenate it with ``skip_input``.

        Args:
            x: feature map to upsample.
            skip_input: encoder feature map to concatenate with.

        Returns:
            Tensor with ``out_channels + skip_input.size(1)`` channels and
            the spatial size of ``skip_input``.
        """
        x = self.model(x)
        # Odd encoder sizes do not survive the halve/double round trip, so
        # zero-pad the upsampled map to the skip connection's spatial size
        # (a negative difference makes F.pad crop instead).
        if x.shape[2:] != skip_input.shape[2:]:
            diff_y = skip_input.size(2) - x.size(2)
            diff_x = skip_input.size(3) - x.size(3)
            x = F.pad(
                x,
                [
                    diff_x // 2,
                    diff_x - diff_x // 2,
                    diff_y // 2,
                    diff_y - diff_y // 2,
                ],
            )
        return torch.cat([x, skip_input], dim=1)


class GeneratorUNet(nn.Module):
    """U-Net generator for the Yandex -> Google image translation.

    Seven downsampling blocks, a bottleneck, seven upsampling blocks with
    skip connections and a final transposed convolution followed by Tanh.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the generated image.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3):
        super().__init__()

        # Downsampling path
        self.down1 = UNetDownBlock(in_channels, 64, normalize=False)
        self.down2 = UNetDownBlock(64, 128)
        self.down3 = UNetDownBlock(128, 256)
        self.down4 = UNetDownBlock(256, 512)
        self.down5 = UNetDownBlock(512, 512)
        self.down6 = UNetDownBlock(512, 512)
        self.down7 = UNetDownBlock(512, 512)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )

        # Upsampling path (input channels are doubled by the skip concat)
        self.up1 = UNetUpBlock(512, 512, dropout=0.5)
        self.up2 = UNetUpBlock(1024, 512, dropout=0.5)
        self.up3 = UNetUpBlock(1024, 512, dropout=0.5)
        self.up4 = UNetUpBlock(1024, 512)
        self.up5 = UNetUpBlock(1024, 256)
        self.up6 = UNetUpBlock(512, 128)
        self.up7 = UNetUpBlock(256, 64)

        # Final layer
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Generate an image from ``x``; output values lie in [-1, 1] (Tanh).

        The size comments below are for a 700x700 input.
        """
        # Downsampling: each block yields floor((n - 2) / 2) + 1
        d1 = self.down1(x)  # 350x350
        d2 = self.down2(d1)  # 175x175
        d3 = self.down3(d2)  # 87x87
        d4 = self.down4(d3)  # 43x43
        d5 = self.down5(d4)  # 21x21
        d6 = self.down6(d5)  # 10x10
        d7 = self.down7(d6)  # 5x5

        # Bottleneck
        u = self.bottleneck(d7)  # 2x2

        # Upsampling with skip connections (padded to the skip size)
        u = self.up1(u, d7)  # 4x4 -> 5x5
        u = self.up2(u, d6)  # 10x10
        u = self.up3(u, d5)  # 20x20 -> 21x21
        u = self.up4(u, d4)  # 42x42 -> 43x43
        u = self.up5(u, d3)  # 86x86 -> 87x87
        u = self.up6(u, d2)  # 174x174 -> 175x175
        u = self.up7(u, d1)  # 350x350

        # Final output
        return self.final(u)  # 700x700


class DiscriminatorPatchGAN(nn.Module):
    """PatchGAN discriminator operating on a channel-wise image pair.

    Args:
        in_channels: total channels of the concatenated pair (e.g. 3 + 3).
        use_sigmoid: when True (default, the historical behaviour) a
            ``Sigmoid`` is appended and the output is a per-patch
            probability. Pass False to get raw scores, which is what
            ``BCEWithLogitsLoss`` ("vanilla" ``GANLoss``) and the "wgangp"
            mode expect. The flag does not change the state-dict keys.
    """

    def __init__(self, in_channels: int = 6, use_sigmoid: bool = True):
        super().__init__()

        def discriminator_block(
            in_filters: int, out_filters: int, normalization: bool = True
        ):
            """Return the layers of one stride-2 discriminator stage."""
            layers = [
                nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)
            ]
            if normalization:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Size comments are for a 700x700 input.
        layers = [
            *discriminator_block(in_channels, 64, normalization=False),  # 350x350
            *discriminator_block(64, 128),  # 175x175
            *discriminator_block(128, 256),  # 87x87
            *discriminator_block(256, 512),  # 43x43
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),  # 42x42
        ]
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, img_A: torch.Tensor, img_B: torch.Tensor) -> torch.Tensor:
        """Score an image pair.

        Args:
            img_A: first image of the pair.
            img_B: second image of the pair (same batch and spatial size).

        Returns:
            One-channel patch map: probabilities when ``use_sigmoid`` was
            True at construction time, raw scores otherwise.
        """
        # Concatenate the two images along the channel axis
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)


class GANLoss(nn.Module):
    """Adversarial loss wrapper.

    Args:
        gan_mode: ``"vanilla"`` (``BCEWithLogitsLoss``; predictions must be
            raw logits), ``"lsgan"`` (``MSELoss``) or ``"wgangp"``
            (signed mean of the prediction).
        target_real_label: label value used for real targets.
        target_fake_label: label value used for fake targets.

    Raises:
        NotImplementedError: for any other ``gan_mode``.
    """

    def __init__(
        self,
        gan_mode: str = "vanilla",
        target_real_label: float = 1.0,
        target_fake_label: float = 0.0,
    ):
        super().__init__()
        self.register_buffer("real_label", torch.tensor(target_real_label))
        self.register_buffer("fake_label", torch.tensor(target_fake_label))
        self.gan_mode = gan_mode

        if gan_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == "lsgan":
            self.loss = nn.MSELoss()
        elif gan_mode == "wgangp":
            self.loss = None
        else:
            raise NotImplementedError(f"GAN mode {gan_mode} not implemented")

    def get_target_tensor(
        self, prediction: torch.Tensor, target_is_real: bool
    ) -> torch.Tensor:
        """Return a label tensor broadcast to the shape of ``prediction``."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        # Match device/dtype so the loss works even if this module was not
        # moved together with the network that produced the prediction.
        target_tensor = target_tensor.to(
            device=prediction.device, dtype=prediction.dtype
        )
        return target_tensor.expand_as(prediction)

    def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
        """Compute the adversarial loss.

        Implemented as ``forward`` (not ``__call__``) so ``nn.Module`` hooks
        keep working; calling the instance behaves as before.

        Args:
            prediction: discriminator output.
            target_is_real: whether the target label is "real".

        Returns:
            Scalar loss tensor.
        """
        if self.gan_mode == "wgangp":
            return -prediction.mean() if target_is_real else prediction.mean()
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target_tensor)


class ImageGAN(nn.Module):
    """Main GAN wrapper for the Yandex -> Google image translation.

    Holds the U-Net generator, the PatchGAN discriminator and the losses.
    The discriminator is built without its final ``Sigmoid`` because
    ``GANLoss`` consumes raw scores (``BCEWithLogitsLoss`` applies the
    sigmoid itself, "wgangp" needs unbounded scores).

    Args:
        input_channels: channels of the source image.
        output_channels: channels of the generated image.
        gan_mode: adversarial loss mode, forwarded to ``GANLoss``.
        lambda_L1: weight of the L1 reconstruction loss.
        use_cuda: use CUDA when it is available.
    """

    def __init__(
        self,
        input_channels: int = 3,
        output_channels: int = 3,
        gan_mode: str = "vanilla",
        lambda_L1: float = 100.0,
        use_cuda: bool = True,
    ):
        super().__init__()
        self.generator = GeneratorUNet(input_channels, output_channels)
        self.discriminator = DiscriminatorPatchGAN(
            input_channels + output_channels, use_sigmoid=False
        )
        self.gan_loss = GANLoss(gan_mode)
        self.l1_loss = nn.L1Loss()
        self.lambda_L1 = lambda_L1
        self.device = torch.device(
            "cuda" if use_cuda and torch.cuda.is_available() else "cpu"
        )
        self.to(self.device)

    def forward(self, yandex_image: torch.Tensor) -> torch.Tensor:
        """Generate a Google-style image from a Yandex image."""
        return self.generator(yandex_image)

    def generator_step(
        self, yandex_image: torch.Tensor, real_google_image: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the generator losses for one training step.

        Returns:
            total_loss: adversarial loss plus weighted L1 loss.
            gan_loss: adversarial loss.
            l1_loss: L1 loss already multiplied by ``lambda_L1``.
        """
        # Generate the image
        fake_google_image = self.generator(yandex_image)

        # Score it with the discriminator
        fake_pred = self.discriminator(yandex_image, fake_google_image)

        # Adversarial loss (the generator tries to fool the discriminator)
        gan_loss = self.gan_loss(fake_pred, True)

        # L1 loss to preserve structure
        l1_loss = self.l1_loss(fake_google_image, real_google_image) * self.lambda_L1

        total_loss = gan_loss + l1_loss
        return total_loss, gan_loss, l1_loss

    def discriminator_step(
        self,
        yandex_image: torch.Tensor,
        real_google_image: torch.Tensor,
        fake_google_image: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the discriminator losses for one training step.

        Returns:
            total_loss: mean of the real and fake losses.
            real_loss: loss on real pairs.
            fake_loss: loss on generated pairs.
        """
        # Predictions for real pairs
        real_pred = self.discriminator(yandex_image, real_google_image)
        real_loss = self.gan_loss(real_pred, True)

        # Predictions for generated pairs; detach so that no gradient
        # flows back into the generator
        fake_pred = self.discriminator(yandex_image, fake_google_image.detach())
        fake_loss = self.gan_loss(fake_pred, False)

        total_loss = (real_loss + fake_loss) * 0.5
        return total_loss, real_loss, fake_loss

    def to(self, *args, **kwargs):
        """Move the whole model and keep ``self.device`` in sync.

        Delegates to ``nn.Module.to`` so that every submodule is moved,
        including the ``GANLoss`` label buffers, and all ``Module.to``
        call forms are accepted.
        """
        super().to(*args, **kwargs)
        first_param = next(self.parameters(), None)
        if first_param is not None:
            self.device = first_param.device
        return self

    def train_mode(self):
        """Switch the generator and discriminator to training mode."""
        self.generator.train()
        self.discriminator.train()

    def eval_mode(self):
        """Switch the generator and discriminator to evaluation mode."""
        self.generator.eval()
        self.discriminator.eval()

    def save_checkpoint(self, path: str):
        """Save network weights (and optimizer states, if attached) to ``path``.

        Optimizer states are read from an ``optimizer_state_dict`` attribute
        on the generator / discriminator and stored as None when absent.
        """
        checkpoint = {
            "generator_state_dict": self.generator.state_dict(),
            "discriminator_state_dict": self.discriminator.state_dict(),
            "generator_optimizer_state_dict": getattr(
                self.generator, "optimizer_state_dict", None
            ),
            "discriminator_optimizer_state_dict": getattr(
                self.discriminator, "optimizer_state_dict", None
            ),
        }
        torch.save(checkpoint, path)

    def load_checkpoint(self, path: str):
        """Load a checkpoint written by ``save_checkpoint`` from ``path``.

        The optimizer entries are optional: checkpoints that lack those
        keys (or store None) load without error.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.generator.load_state_dict(checkpoint["generator_state_dict"])
        self.discriminator.load_state_dict(checkpoint["discriminator_state_dict"])

        gen_opt_state = checkpoint.get("generator_optimizer_state_dict")
        if gen_opt_state is not None:
            self.generator.optimizer_state_dict = gen_opt_state

        disc_opt_state = checkpoint.get("discriminator_optimizer_state_dict")
        if disc_opt_state is not None:
            self.discriminator.optimizer_state_dict = disc_opt_state


def create_image_gan(
    input_channels: int = 3,
    output_channels: int = 3,
    gan_mode: str = "vanilla",
    lambda_L1: float = 100.0,
    use_cuda: bool = True,
) -> ImageGAN:
    """Create and return the image-translation GAN.

    Args:
        input_channels: number of input channels (usually 3 for RGB).
        output_channels: number of output channels (usually 3 for RGB).
        gan_mode: GAN mode ('vanilla', 'lsgan', 'wgangp').
        lambda_L1: weight of the L1 loss.
        use_cuda: use CUDA when it is available.

    Returns:
        ImageGAN: the GAN model.
    """
    return ImageGAN(
        input_channels=input_channels,
        output_channels=output_channels,
        gan_mode=gan_mode,
        lambda_L1=lambda_L1,
        use_cuda=use_cuda,
    )


# Weight-initialisation helpers
def weights_init_normal(m):
    """Initialise one module in place; meant for ``nn.Module.apply``.

    Modules whose class name contains "Conv" get weights from N(0, 0.02).
    Modules whose class name contains "BatchNorm" get weights from
    N(1, 0.02) and a zero bias. Layers without a weight/bias (for example
    ``affine=False`` norms) are left untouched.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, "bias", None)
        if bias is not None:
            nn.init.constant_(bias.data, 0.0)


def initialize_gan_weights(generator: nn.Module, discriminator: nn.Module):
    """Apply ``weights_init_normal`` to the generator and the discriminator."""
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)