394 lines
14 KiB
Python
394 lines
14 KiB
Python
from typing import Optional, Tuple
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
|
||
class UNetDownBlock(nn.Module):
    """Downsampling stage of the U-Net generator.

    A strided 4x4 convolution (stride 2, padding 1) halves the spatial size
    (H -> floor(H / 2)). It is followed by an optional ``BatchNorm2d``, a
    ``LeakyReLU(0.2)`` and an optional ``Dropout2d``.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        normalize: insert ``BatchNorm2d`` after the convolution when True.
        dropout: ``Dropout2d`` probability; no dropout layer is added when <= 0.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        normalize: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()
        # The convolution is built without a bias term.
        downsample = nn.Conv2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False
        )
        norm_layers = [nn.BatchNorm2d(out_channels)] if normalize else []
        dropout_layers = [nn.Dropout2d(dropout)] if dropout > 0 else []
        # Layer order (and therefore state_dict indices): conv, [bn], act, [drop].
        self.model = nn.Sequential(
            downsample,
            *norm_layers,
            nn.LeakyReLU(0.2, inplace=True),
            *dropout_layers,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stage on ``x`` and return the downsampled feature map."""
        return self.model(x)
|
||
|
||
|
||
class UNetUpBlock(nn.Module):
    """Upsampling stage of the U-Net generator.

    A 4x4 transposed convolution (stride 2, padding 1) doubles the spatial
    size. It is followed by ``BatchNorm2d``, ``ReLU`` and an optional
    ``Dropout2d``. The result is concatenated with the matching encoder
    feature map (skip connection) along the channel axis.

    Args:
        in_channels: channels of the incoming decoder feature map.
        out_channels: channels produced by the transposed convolution.
        dropout: ``Dropout2d`` probability; no dropout layer is added when <= 0.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if dropout > 0:
            layers.append(nn.Dropout2d(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` and concatenate ``skip_input`` on dim 1.

        Returns a tensor with ``out_channels + skip_input.size(1)`` channels
        and the spatial size of ``skip_input``.
        """
        x = self.model(x)
        # When an encoder stage halved an odd size, doubling here does not
        # restore it exactly. Resize the *upsampled* map (not the skip) so its
        # spatial dims equal skip_input's: positive differences zero-pad,
        # negative ones crop (F.pad accepts negative padding). Only the
        # spatial dims are compared; channel counts legitimately differ.
        diff_y = skip_input.size(2) - x.size(2)
        diff_x = skip_input.size(3) - x.size(3)
        if diff_y or diff_x:
            x = F.pad(
                x,
                [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2],
            )
        return torch.cat([x, skip_input], dim=1)
|
||
|
||
|
||
class GeneratorUNet(nn.Module):
    """U-Net generator for Yandex -> Google image translation.

    Seven downsampling stages plus a bottleneck halve the resolution eight
    times. Seven upsampling stages with skip connections mirror them, and a
    final transposed convolution + ``Tanh`` produces values in [-1, 1].

    Because of the eight halvings, input height and width should be at least
    256 so that the bottleneck map is not empty. The output always has the
    same spatial size as the input (see ``forward``).

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3):
        super().__init__()

        # Downsampling path (encoder)
        self.down1 = UNetDownBlock(in_channels, 64, normalize=False)
        self.down2 = UNetDownBlock(64, 128)
        self.down3 = UNetDownBlock(128, 256)
        self.down4 = UNetDownBlock(256, 512)
        self.down5 = UNetDownBlock(512, 512)
        self.down6 = UNetDownBlock(512, 512)
        self.down7 = UNetDownBlock(512, 512)

        # Bottleneck: one more halving, no normalization
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )

        # Upsampling path (decoder); input channels include the concatenated skip
        self.up1 = UNetUpBlock(512, 512, dropout=0.5)
        self.up2 = UNetUpBlock(1024, 512, dropout=0.5)
        self.up3 = UNetUpBlock(1024, 512, dropout=0.5)
        self.up4 = UNetUpBlock(1024, 512)
        self.up5 = UNetUpBlock(1024, 256)
        self.up6 = UNetUpBlock(512, 128)
        self.up7 = UNetUpBlock(256, 64)

        # Final layer: back to full resolution, output in [-1, 1]
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Translate ``x`` and return an image with the same spatial size."""
        # Downsampling. Each stage maps H -> floor(H / 2); sizes are for 700x700.
        d1 = self.down1(x)  # 350x350
        d2 = self.down2(d1)  # 175x175
        d3 = self.down3(d2)  # 87x87
        d4 = self.down4(d3)  # 43x43
        d5 = self.down5(d4)  # 21x21
        d6 = self.down6(d5)  # 10x10
        d7 = self.down7(d6)  # 5x5

        # Bottleneck
        u = self.bottleneck(d7)  # 2x2

        # Upsampling with skip connections. Each stage doubles the size, then
        # UNetUpBlock pads to the skip's size where halving lost a pixel.
        u = self.up1(u, d7)  # 4 -> 5x5
        u = self.up2(u, d6)  # 10x10
        u = self.up3(u, d5)  # 20 -> 21x21
        u = self.up4(u, d4)  # 42 -> 43x43
        u = self.up5(u, d3)  # 86 -> 87x87
        u = self.up6(u, d2)  # 174 -> 175x175
        u = self.up7(u, d1)  # 350x350

        # Final output
        out = self.final(u)  # 700x700

        # The final stage yields 2 * floor(H / 2), one pixel short of the
        # input for odd H or W. Replicate-pad the missing border so the output
        # can be compared pixel-wise with a same-size target. The difference
        # is always 0 or 1, so this never crops.
        diff_y = x.size(-2) - out.size(-2)
        diff_x = x.size(-1) - out.size(-1)
        if diff_y or diff_x:
            out = F.pad(
                out,
                [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2],
                mode="replicate",
            )
        return out
|
||
|
||
|
||
class DiscriminatorPatchGAN(nn.Module):
    """PatchGAN discriminator for image pairs (e.g. 700x700).

    Scores every receptive-field patch of the channel-wise concatenation of
    two images. By default the scores are raw logits. That is what
    ``GANLoss`` expects: its "vanilla" mode uses ``BCEWithLogitsLoss``, which
    applies the sigmoid itself, and "lsgan"/"wgangp" operate on unbounded
    scores. Appending a sigmoid here as well would apply it twice.

    Args:
        in_channels: channels of the concatenated pair (3 real/condition +
            3 generated for RGB).
        use_sigmoid: if True, append ``nn.Sigmoid`` so ``forward`` returns
            per-patch probabilities. Keep False when training with ``GANLoss``.
    """

    def __init__(self, in_channels: int = 6, use_sigmoid: bool = False):
        super().__init__()

        def discriminator_block(
            in_filters: int, out_filters: int, normalization: bool = True
        ):
            """Strided 4x4 conv (halves H and W), optional BatchNorm, LeakyReLU(0.2)."""
            layers = [
                nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)
            ]
            if normalization:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Sizes in comments are for a 700x700 input (H -> floor(H / 2) per block).
        layers = [
            *discriminator_block(in_channels, 64, normalization=False),  # 350x350
            *discriminator_block(64, 128),  # 175x175
            *discriminator_block(128, 256),  # 87x87
            *discriminator_block(256, 512),  # 43x43
            # Stride-1 conv to one score channel: H -> H - 1
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),  # 42x42
        ]
        if use_sigmoid:
            # Appended last, so state_dict keys match the logits variant.
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, img_A: torch.Tensor, img_B: torch.Tensor) -> torch.Tensor:
        """Score the pair (``img_A``, ``img_B``).

        The two images are concatenated along dim 1 and passed through the
        network. Returns one score per patch: logits by default, or
        probabilities when constructed with ``use_sigmoid=True``.
        """
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)
|
||
|
||
|
||
class GANLoss(nn.Module):
    """Adversarial loss for GAN training.

    Modes:
        "vanilla": ``BCEWithLogitsLoss`` between raw scores and the target label.
        "lsgan": ``MSELoss`` between scores and the target label.
        "wgangp": Wasserstein critic term, ``-mean(pred)`` for real and
            ``mean(pred)`` for fake. No gradient penalty is computed by this
            class.

    Args:
        gan_mode: one of "vanilla", "lsgan", "wgangp".
        target_real_label: label value used when ``target_is_real`` is True.
        target_fake_label: label value used when ``target_is_real`` is False.

    Raises:
        NotImplementedError: for an unknown ``gan_mode``.
    """

    def __init__(
        self,
        gan_mode: str = "vanilla",
        target_real_label: float = 1.0,
        target_fake_label: float = 0.0,
    ):
        super().__init__()
        # Stored as buffers so they move with the module and are saved in state_dict.
        self.register_buffer("real_label", torch.tensor(target_real_label))
        self.register_buffer("fake_label", torch.tensor(target_fake_label))
        self.gan_mode = gan_mode

        if gan_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == "lsgan":
            self.loss = nn.MSELoss()
        elif gan_mode == "wgangp":
            self.loss = None
        else:
            raise NotImplementedError(f"GAN mode {gan_mode} not implemented")

    def get_target_tensor(
        self, prediction: torch.Tensor, target_is_real: bool
    ) -> torch.Tensor:
        """Return the real/fake label broadcast to ``prediction``'s shape."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        # Match device and dtype explicitly. The buffers may not have been
        # moved together with the networks, and mixed precision may produce
        # non-float32 predictions.
        target_tensor = target_tensor.to(
            device=prediction.device, dtype=prediction.dtype
        )
        return target_tensor.expand_as(prediction)

    def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
        """Compute the adversarial loss for ``prediction``.

        Implemented as ``forward`` (not ``__call__``) so ``nn.Module``
        hooks run. Call the module as ``loss_fn(pred, True)``.
        """
        if self.gan_mode in ("vanilla", "lsgan"):
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            return self.loss(prediction, target_tensor)
        if self.gan_mode == "wgangp":
            return -prediction.mean() if target_is_real else prediction.mean()
        raise NotImplementedError(f"GAN mode {self.gan_mode} not implemented")
|
||
|
||
|
||
class ImageGAN(nn.Module):
    """Conditional (pix2pix-style) GAN for Yandex -> Google image translation.

    Bundles a ``GeneratorUNet``, a ``DiscriminatorPatchGAN`` that scores
    (Yandex, Google) pairs concatenated along channels, an adversarial
    ``GANLoss`` and an L1 reconstruction loss weighted by ``lambda_L1``.

    Args:
        input_channels: channels of the Yandex (source) image.
        output_channels: channels of the generated / Google (target) image.
        gan_mode: adversarial loss mode passed to ``GANLoss``.
        lambda_L1: weight of the L1 term in the generator loss.
        use_cuda: place the model on CUDA when available, otherwise on CPU.
    """

    def __init__(
        self,
        input_channels: int = 3,
        output_channels: int = 3,
        gan_mode: str = "vanilla",
        lambda_L1: float = 100.0,
        use_cuda: bool = True,
    ):
        super().__init__()

        self.generator = GeneratorUNet(input_channels, output_channels)
        # The discriminator sees the source and a candidate target stacked on dim 1.
        self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
        self.gan_loss = GANLoss(gan_mode)
        self.l1_loss = nn.L1Loss()
        self.lambda_L1 = lambda_L1

        self.device = torch.device(
            "cuda" if use_cuda and torch.cuda.is_available() else "cpu"
        )
        self.to(self.device)

    def forward(self, yandex_image: torch.Tensor) -> torch.Tensor:
        """Generate a Google-style image from a Yandex image."""
        return self.generator(yandex_image)

    def generator_step(
        self, yandex_image: torch.Tensor, real_google_image: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the generator losses for one training step.

        Returns:
            total_loss: gan_loss + l1_loss
            gan_loss: adversarial loss (generated pair labelled as real)
            l1_loss: L1 loss to the real image, already multiplied by lambda_L1
        """
        # Generate an image
        fake_google_image = self.generator(yandex_image)

        # Score it with the discriminator
        fake_pred = self.discriminator(yandex_image, fake_google_image)

        # Adversarial loss (try to fool the discriminator)
        gan_loss = self.gan_loss(fake_pred, True)

        # Weighted L1 loss keeps the output close to the target pixel-wise
        l1_loss = self.l1_loss(fake_google_image, real_google_image) * self.lambda_L1

        total_loss = gan_loss + l1_loss
        return total_loss, gan_loss, l1_loss

    def discriminator_step(
        self,
        yandex_image: torch.Tensor,
        real_google_image: torch.Tensor,
        fake_google_image: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the discriminator losses for one training step.

        Returns:
            total_loss: mean of real_loss and fake_loss
            real_loss: loss on real pairs (labelled real)
            fake_loss: loss on generated pairs (labelled fake)
        """
        real_pred = self.discriminator(yandex_image, real_google_image)
        real_loss = self.gan_loss(real_pred, True)

        # detach(): the discriminator update must not backpropagate into the generator.
        fake_pred = self.discriminator(yandex_image, fake_google_image.detach())
        fake_loss = self.gan_loss(fake_pred, False)

        total_loss = (real_loss + fake_loss) * 0.5
        return total_loss, real_loss, fake_loss

    def to(self, *args, **kwargs):
        """Move/cast the whole model, including the GANLoss label buffers.

        Accepts the same arguments as ``nn.Module.to`` (a positional device,
        as before, or device/dtype keywords). Keeps ``self.device`` in sync
        with where the parameters ended up, and returns ``self``.
        """
        super().to(*args, **kwargs)
        self.device = next(self.parameters()).device
        return self

    def train_mode(self):
        """Put the generator and discriminator into training mode."""
        self.generator.train()
        self.discriminator.train()

    def eval_mode(self):
        """Put the generator and discriminator into evaluation mode."""
        self.generator.eval()
        self.discriminator.eval()

    def save_checkpoint(self, path: str):
        """Save generator/discriminator weights (and any attached optimizer state) to ``path``."""
        checkpoint = {
            "generator_state_dict": self.generator.state_dict(),
            "discriminator_state_dict": self.discriminator.state_dict(),
            # Optimizer states are saved only if a caller attached them as attributes.
            "generator_optimizer_state_dict": getattr(
                self.generator, "optimizer_state_dict", None
            ),
            "discriminator_optimizer_state_dict": getattr(
                self.discriminator, "optimizer_state_dict", None
            ),
        }
        torch.save(checkpoint, path)

    def load_checkpoint(self, path: str):
        """Load weights (and optional optimizer state) saved by ``save_checkpoint``.

        Optimizer entries are optional. Missing or ``None`` entries are
        skipped; they do not raise ``KeyError``.
        """
        # NOTE: torch.load unpickles data; only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device)
        self.generator.load_state_dict(checkpoint["generator_state_dict"])
        self.discriminator.load_state_dict(checkpoint["discriminator_state_dict"])

        for network, key in (
            (self.generator, "generator_optimizer_state_dict"),
            (self.discriminator, "discriminator_optimizer_state_dict"),
        ):
            optimizer_state = checkpoint.get(key)
            if optimizer_state is not None:
                network.optimizer_state_dict = optimizer_state
|
||
|
||
|
||
def create_image_gan(
    input_channels: int = 3,
    output_channels: int = 3,
    gan_mode: str = "vanilla",
    lambda_L1: float = 100.0,
    use_cuda: bool = True,
) -> ImageGAN:
    """
    Build an ImageGAN for image-to-image translation.

    Args:
        input_channels: number of input channels (typically 3 for RGB)
        output_channels: number of output channels (typically 3 for RGB)
        gan_mode: GAN loss mode ('vanilla', 'lsgan', 'wgangp')
        lambda_L1: weight of the L1 loss
        use_cuda: use CUDA when it is available

    Returns:
        ImageGAN: the constructed model
    """
    # Arguments are passed positionally, in ImageGAN.__init__'s parameter order.
    model = ImageGAN(input_channels, output_channels, gan_mode, lambda_L1, use_cuda)
    return model
|
||
|
||
|
||
# Вспомогательные функции для инициализации весов
|
||
def weights_init_normal(m: nn.Module) -> None:
    """Initialize one module's parameters; intended for ``nn.Module.apply``.

    - Class name contains "Conv" (Conv2d, ConvTranspose2d, ...): weight is
      drawn from N(0, 0.02). The bias is left unchanged.
    - Class name contains "BatchNorm": weight is drawn from N(1, 0.02) and
      bias is set to 0.
    - Any other module, or one without the corresponding parameter (e.g.
      ``affine=False`` norms), is left untouched.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        # BatchNorm's shift parameter is ``m.bias``; there is no ``m.batch_norm``.
        bias = getattr(m, "bias", None)
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0.0)
|
||
|
||
|
||
def initialize_gan_weights(generator: nn.Module, discriminator: nn.Module):
    """Apply ``weights_init_normal`` to every submodule of both networks.

    The generator is initialized first, then the discriminator.
    """
    for network in (generator, discriminator):
        network.apply(weights_init_normal)
|