256 lines
8.7 KiB
Python
256 lines
8.7 KiB
Python
"""GAN model for image translation Yandex -> Google."""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class UNetDownBlock(nn.Module):
    """Encoder stage of the U-Net: a strided 4x4 convolution that halves H and W.

    The stage is ``Conv2d(bias=False) -> [BatchNorm2d] -> LeakyReLU(0.2) -> [Dropout2d]``.
    The bracketed layers are present only when ``normalize`` is true or
    ``dropout`` is positive.
    """

    def __init__(self, in_channels: int, out_channels: int, normalize: bool = True, dropout: float = 0.0):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,  # a following BatchNorm supplies the shift, so no bias
        )
        norm_layers = [nn.BatchNorm2d(out_channels)] if normalize else []
        dropout_layers = [nn.Dropout2d(dropout)] if dropout > 0 else []
        self.model = nn.Sequential(
            conv,
            *norm_layers,
            nn.LeakyReLU(0.2, inplace=True),
            *dropout_layers,
        )

    def forward(self, x):
        """Run ``x`` through the downsampling stack and return the result."""
        return self.model(x)
|
|
|
|
|
|
class UNetUpBlock(nn.Module):
    """Upsampling block for U-Net.

    The block doubles the spatial size of its input with a strided 4x4
    transposed convolution and aligns the result to the skip tensor's H and W.
    It then applies BatchNorm, ReLU and optional Dropout2d, and concatenates
    the skip tensor along the channel axis. The output therefore has
    ``out_channels + skip_input.size(1)`` channels.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # The attribute always exists (None when disabled) so forward() can test for it.
        self.dropout = nn.Dropout2d(dropout) if dropout > 0 else None

    def forward(self, x, skip_input):
        """Upsample ``x`` and concatenate it with ``skip_input``.

        Args:
            x: feature map from the previous (coarser) decoder level; its channel
                count must equal the ``in_channels`` given at construction.
            skip_input: encoder feature map at the target resolution.

        Returns:
            ``torch.cat([upsampled_x, skip_input], dim=1)``.
        """
        x = self.upconv(x)
        # The stride-2 encoder convs round odd sizes down, so the upsampled map
        # can be off by a pixel from the skip map. Pad it (negative amounts
        # crop) to the skip map's H and W. Only spatial dims are compared;
        # channel counts need not match for concatenation.
        diff_h = skip_input.size(2) - x.size(2)
        diff_w = skip_input.size(3) - x.size(3)
        if diff_h or diff_w:
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
        x = self.norm(x)
        x = self.relu(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return torch.cat([x, skip_input], dim=1)
|
|
|
|
|
|
class GeneratorUNet(nn.Module):
    """U-Net generator for Yandex -> Google translation.

    Seven downsampling blocks plus a stride-2 bottleneck halve H and W eight
    times. Inputs must therefore be at least 256x256; smaller inputs collapse
    the bottleneck to zero size. For a 256x256 input the bottleneck output is
    512x1x1.

    Each decoder block concatenates its output with the matching encoder map,
    so a decoder block's ``in_channels`` is the previous block's
    ``out_channels`` plus the skip map's channels. The output goes through
    ``tanh``, so values lie in [-1, 1].
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3):
        super().__init__()

        # Downsampling: each block halves H and W (4x4 conv, stride 2, padding 1).
        self.down1 = UNetDownBlock(in_channels, 64, normalize=False)  # -> 64 ch
        self.down2 = UNetDownBlock(64, 128)  # -> 128 ch
        self.down3 = UNetDownBlock(128, 256)  # -> 256 ch
        self.down4 = UNetDownBlock(256, 512)  # -> 512 ch
        self.down5 = UNetDownBlock(512, 512)
        self.down6 = UNetDownBlock(512, 512)
        self.down7 = UNetDownBlock(512, 512)

        # Bottleneck: one more halving, without normalization.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )

        # Upsampling. up1 consumes the bottleneck output directly; no skip has
        # been concatenated yet, so it takes 512 channels. (It was previously
        # declared with 1024, which made every forward pass fail.)
        self.up1 = UNetUpBlock(512, 512, dropout=0.5)  # 512 + d7(512) -> 1024
        self.up2 = UNetUpBlock(1024, 512, dropout=0.5)  # 512 + d6(512) -> 1024
        self.up3 = UNetUpBlock(1024, 512, dropout=0.5)  # 512 + d5(512) -> 1024
        self.up4 = UNetUpBlock(1024, 512)  # 512 + d4(512) -> 1024
        self.up5 = UNetUpBlock(1024, 256)  # 256 + d3(256) -> 512
        self.up6 = UNetUpBlock(512, 128)  # 128 + d2(128) -> 256
        self.up7 = UNetUpBlock(256, 64)  # 64 + d1(64) -> 128

        # Final: double H and W once more and map to the output channels.
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate an image batch ``x`` and return the generated batch."""
        # Down
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)

        # Bottleneck
        u = self.bottleneck(d7)

        # Up with skip connections (deepest encoder map first)
        u = self.up1(u, d7)
        u = self.up2(u, d6)
        u = self.up3(u, d5)
        u = self.up4(u, d4)
        u = self.up5(u, d3)
        u = self.up6(u, d2)
        u = self.up7(u, d1)

        return self.final(u)
|
|
|
|
|
|
class DiscriminatorPatchGAN(nn.Module):
    """PatchGAN discriminator.

    The discriminator scores an (img_A, img_B) pair, stacked on the channel
    axis, with a single-channel map of per-patch scores. For 256x256 inputs
    the map is 15x15: four stride-2 convs take 256 to 16, and a final 4x4
    stride-1 conv gives 15.

    By default the raw scores (logits) are returned. ``GANLoss("vanilla")``
    uses ``BCEWithLogitsLoss``, which applies the sigmoid itself, and the
    "lsgan"/"wgangp" modes operate on raw scores. A built-in sigmoid would be
    applied twice and squash the gradients. Pass ``use_sigmoid=True`` only if
    you need probabilities and compute the loss yourself.
    """

    def __init__(self, in_channels: int = 6, use_sigmoid: bool = False):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
        ]
        # The sigmoid was the last, parameterless layer, so dropping it keeps
        # existing state_dict keys (model.0 ... model.11) unchanged.
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        """Concatenate ``img_A`` and ``img_B`` on dim 1 and return the patch-score map."""
        x = torch.cat([img_A, img_B], dim=1)
        return self.model(x)
|
|
|
|
|
|
class GANLoss(nn.Module):
    """Adversarial loss for three objectives, selected by ``gan_mode``.

    * ``"vanilla"``: ``nn.BCEWithLogitsLoss`` against a constant target.
    * ``"lsgan"``: ``nn.MSELoss`` against a constant target.
    * ``"wgangp"``: the mean prediction, negated when the target is "real".

    Any other mode raises ``ValueError``.
    """

    def __init__(self, gan_mode: str = "vanilla", target_real: float = 1.0, target_fake: float = 0.0):
        super().__init__()
        self.gan_mode = gan_mode
        # Registered as buffers so .to(device) moves the target labels with the module.
        self.register_buffer("real_label", torch.tensor(target_real))
        self.register_buffer("fake_label", torch.tensor(target_fake))

        for mode, criterion_cls in (
            ("vanilla", nn.BCEWithLogitsLoss),
            ("lsgan", nn.MSELoss),
            ("wgangp", None),
        ):
            if gan_mode == mode:
                # "wgangp" scores the prediction directly and has no criterion object.
                self.loss_fn = None if criterion_cls is None else criterion_cls()
                break
        else:
            raise ValueError(f"Unknown GAN mode: {gan_mode}")

    def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
        """Return the loss of ``prediction`` against the real or fake target."""
        mode = self.gan_mode
        if mode in ("vanilla", "lsgan"):
            label = self.real_label if target_is_real else self.fake_label
            # Broadcast the scalar label buffer to the prediction's shape.
            return self.loss_fn(prediction, label.expand_as(prediction))
        if mode == "wgangp":
            score = prediction.mean()
            return -score if target_is_real else score
        # Reachable only if gan_mode was reassigned to an unsupported value after __init__.
        return None
|
|
|
|
|
|
class ImageGAN(nn.Module):
    """Generator, PatchGAN discriminator and losses for Yandex -> Google translation.

    The class only computes losses; ``generator_step`` and
    ``discriminator_step`` return loss tensors and leave backward passes and
    optimizer updates to the caller.
    """

    def __init__(
        self,
        input_channels: int = 3,
        output_channels: int = 3,
        gan_mode: str = "vanilla",
        lambda_L1: float = 100.0,
        use_cuda: bool = True,
    ):
        super().__init__()
        self.generator = GeneratorUNet(input_channels, output_channels)
        # The discriminator receives source and target images stacked on the channel axis.
        self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
        self.gan_loss = GANLoss(gan_mode)
        self.l1_loss = nn.L1Loss()
        self.lambda_L1 = lambda_L1  # weight of the L1 term in the generator loss

        # Fall back to CPU when CUDA is requested but unavailable.
        if use_cuda and torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.to(self.device)

    def forward(self, yandex_image):
        """Return the generator's translation of ``yandex_image``."""
        return self.generator(yandex_image)

    def generator_step(self, yandex_img, real_google_img):
        """Compute the generator objective.

        Returns:
            ``(total, adversarial, weighted_l1)``. ``adversarial`` is the GAN
            loss with the discriminator's verdict on the generated image
            scored against the "real" target. ``weighted_l1`` is the L1
            distance to ``real_google_img`` times ``lambda_L1``. ``total`` is
            their sum.
        """
        generated = self.generator(yandex_img)
        verdict = self.discriminator(yandex_img, generated)
        adversarial = self.gan_loss(verdict, True)
        weighted_l1 = self.l1_loss(generated, real_google_img) * self.lambda_L1
        return adversarial + weighted_l1, adversarial, weighted_l1

    def discriminator_step(self, yandex_img, real_google_img, fake_google_img):
        """Compute the discriminator objective.

        ``fake_google_img`` is detached, so this loss does not backpropagate
        into the generator.

        Returns:
            ``(total, real_loss, fake_loss)``, where ``total`` is the mean of
            the two terms.
        """
        on_real = self.gan_loss(self.discriminator(yandex_img, real_google_img), True)
        on_fake = self.gan_loss(self.discriminator(yandex_img, fake_google_img.detach()), False)
        return (on_real + on_fake) * 0.5, on_real, on_fake
|
|
|
|
|
|
def create_gan(
    input_channels: int = 3,
    output_channels: int = 3,
    gan_mode: str = "vanilla",
    lambda_L1: float = 100.0,
    use_cuda: bool = True,
) -> ImageGAN:
    """Build an ``ImageGAN``, forwarding every argument unchanged by keyword."""
    settings = dict(
        input_channels=input_channels,
        output_channels=output_channels,
        gan_mode=gan_mode,
        lambda_L1=lambda_L1,
        use_cuda=use_cuda,
    )
    return ImageGAN(**settings)
|
|
|
|
|
|
def initialize_weights(model: nn.Module) -> None:
    """Initialize model weights in place with DCGAN/pix2pix-style normal init.

    * ``Conv2d`` and ``ConvTranspose2d``: weights ~ N(0, 0.02), and biases
      (when present) are set to 0. Transposed convolutions are included
      because the U-Net decoder is built from them. Conv biases are zeroed so
      they do not keep PyTorch's default scale next to the small weights.
    * ``BatchNorm2d``: weights ~ N(1, 0.02) and biases set to 0. Layers built
      with ``affine=False`` have no weight/bias and are skipped instead of
      crashing.

    Other module types are left untouched.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d) and m.affine:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)
|
|
|
|
|
|
if __name__ == "__main__":
    # Quick smoke test: build the model on CPU, run one generator forward pass,
    # and report parameter counts.
    model = create_gan(use_cuda=False)
    print(f"Model created on {model.device}")

    # Test forward pass. No autograd graph is needed just to check the output shape.
    test_input = torch.randn(2, 3, 256, 256).to(model.device)
    with torch.no_grad():
        output = model(test_input)
    print(f"Output shape: {output.shape}")

    # Count parameters
    gen_params = sum(p.numel() for p in model.generator.parameters())
    disc_params = sum(p.numel() for p in model.discriminator.parameters())
    print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params")