"""GAN model for image translation Yandex -> Google.""" import torch import torch.nn as nn import torch.nn.functional as F class UNetDownBlock(nn.Module): """Downsampling block for U-Net.""" def __init__(self, in_channels: int, out_channels: int, normalize: bool = True, dropout: float = 0.0): super().__init__() layers = [ nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False) ] if normalize: layers.append(nn.BatchNorm2d(out_channels)) layers.append(nn.LeakyReLU(0.2, inplace=True)) if dropout > 0: layers.append(nn.Dropout2d(dropout)) self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x) class UNetUpBlock(nn.Module): """Upsampling block for U-Net.""" def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0): super().__init__() self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False) self.norm = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) if dropout > 0: self.dropout = nn.Dropout2d(dropout) else: self.dropout = None def forward(self, x, skip_input): print(x.shape) print(skip_input.shape) x = self.upconv(x) # Pad if needed to match skip connection size if x.shape != skip_input.shape: diff_h = skip_input.size(2) - x.size(2) diff_w = skip_input.size(3) - x.size(3) x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]) x = self.norm(x) x = self.relu(x) if self.dropout: x = self.dropout(x) x = torch.cat([x, skip_input], dim=1) return x class GeneratorUNet(nn.Module): """U-Net generator for Yandex -> Google translation.""" def __init__(self, in_channels: int = 3, out_channels: int = 3): super().__init__() # Downsampling self.down1 = UNetDownBlock(in_channels, 64, normalize=False) self.down2 = UNetDownBlock(64, 128) self.down3 = UNetDownBlock(128, 256) self.down4 = UNetDownBlock(256, 512) self.down5 = UNetDownBlock(512, 512) self.down6 = UNetDownBlock(512, 512) self.down7 = UNetDownBlock(512, 512) # Bottleneck self.bottleneck = nn.Sequential( nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False), nn.ReLU(inplace=True), ) # Upsampling self.up1 = UNetUpBlock(1024, 512, dropout=0.5) self.up2 = UNetUpBlock(1024, 512, dropout=0.5) self.up3 = UNetUpBlock(1024, 512, dropout=0.5) self.up4 = UNetUpBlock(1024, 512) self.up5 = UNetUpBlock(1024, 256) self.up6 = UNetUpBlock(512, 128) self.up7 = UNetUpBlock(256, 64) # Final self.final = nn.Sequential( nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1), nn.Tanh(), ) def forward(self, x): # Down d1 = self.down1(x) d2 = self.down2(d1) d3 = self.down3(d2) d4 = self.down4(d3) d5 = self.down5(d4) d6 = self.down6(d5) d7 = self.down7(d6) # Bottleneck u = self.bottleneck(d7) # Up with skip connections u = self.up1(u, d7) u = self.up2(u, d6) u = self.up3(u, d5) u = self.up4(u, d4) u = self.up5(u, d3) u = self.up6(u, d2) u = self.up7(u, d1) return self.final(u) class DiscriminatorPatchGAN(nn.Module): """PatchGAN discriminator.""" def __init__(self, in_channels: int = 6): super().__init__() self.model = nn.Sequential( nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1), nn.Sigmoid(), ) def forward(self, img_A, img_B): x = torch.cat([img_A, img_B], dim=1) return self.model(x) class GANLoss(nn.Module): """GAN loss supporting different GAN modes.""" def __init__(self, gan_mode: str = "vanilla", target_real: float = 1.0, target_fake: float = 0.0): super().__init__() self.gan_mode = gan_mode self.register_buffer("real_label", torch.tensor(target_real)) self.register_buffer("fake_label", torch.tensor(target_fake)) if gan_mode == "vanilla": self.loss_fn = nn.BCEWithLogitsLoss() elif gan_mode == "lsgan": self.loss_fn = nn.MSELoss() elif gan_mode == "wgangp": self.loss_fn = None else: raise ValueError(f"Unknown GAN mode: {gan_mode}") def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor: if self.gan_mode in ["vanilla", "lsgan"]: target = self.real_label if target_is_real else self.fake_label target = target.expand_as(prediction) return self.loss_fn(prediction, target) elif self.gan_mode == "wgangp": return -prediction.mean() if target_is_real else prediction.mean() class ImageGAN(nn.Module): """Complete GAN model for image translation.""" def __init__( self, input_channels: int = 3, output_channels: int = 3, gan_mode: str = "vanilla", lambda_L1: float = 100.0, use_cuda: bool = True, ): super().__init__() self.generator = GeneratorUNet(input_channels, output_channels) self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels) self.gan_loss = GANLoss(gan_mode) self.l1_loss = nn.L1Loss() self.lambda_L1 = lambda_L1 self.device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu") self.to(self.device) def forward(self, yandex_image): """Generate Google image from Yandex.""" return self.generator(yandex_image) def generator_step(self, yandex_img, real_google_img): """Compute generator losses.""" fake_google = self.generator(yandex_img) fake_pred = self.discriminator(yandex_img, fake_google) gan_loss = self.gan_loss(fake_pred, True) l1_loss = self.l1_loss(fake_google, real_google_img) * self.lambda_L1 total_loss = gan_loss + l1_loss return total_loss, gan_loss, l1_loss def discriminator_step(self, yandex_img, real_google_img, fake_google_img): """Compute discriminator losses.""" real_pred = self.discriminator(yandex_img, real_google_img) real_loss = self.gan_loss(real_pred, True) fake_pred = self.discriminator(yandex_img, fake_google_img.detach()) fake_loss = self.gan_loss(fake_pred, False) total_loss = (real_loss + fake_loss) * 0.5 return total_loss, real_loss, fake_loss def create_gan( input_channels: int = 3, output_channels: int = 3, gan_mode: str = "vanilla", lambda_L1: float = 100.0, use_cuda: bool = True, ) -> ImageGAN: """Create a GAN model.""" return ImageGAN( input_channels=input_channels, output_channels=output_channels, gan_mode=gan_mode, lambda_L1=lambda_L1, use_cuda=use_cuda, ) def initialize_weights(model: nn.Module): """Initialize model weights.""" for m in model.modules(): if isinstance(m, nn.Conv2d): nn.init.normal_(m.weight.data, 0.0, 0.02) elif isinstance(m, nn.BatchNorm2d): nn.init.normal_(m.weight.data, 1.0, 0.02) nn.init.constant_(m.bias.data, 0.0) if __name__ == "__main__": # Quick test device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = create_gan(use_cuda=False) print(f"Model created on {model.device}") # Test forward pass test_input = torch.randn(2, 3, 256, 256).to(model.device) output = model(test_input) print(f"Output shape: {output.shape}") # Count parameters gen_params = sum(p.numel() for p in model.generator.parameters()) disc_params = sum(p.numel() for p in model.discriminator.parameters()) print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params")