Files
autopilot/models/GAN/model.py
2026-04-04 17:50:10 +03:00

255 lines
9.1 KiB
Python

"""GAN model for image translation Yandex -> Google."""
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNetDownBlock(nn.Module):
    """Encoder stage of the U-Net: Conv(4x4, stride 2) -> [BatchNorm] -> LeakyReLU(0.2) -> [Dropout2d].

    With kernel 4, stride 2 and padding 1 the spatial size becomes ``floor(H / 2)``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        normalize: insert ``BatchNorm2d`` after the convolution when True.
        dropout: ``Dropout2d`` probability; the layer is omitted entirely when 0.
    """

    def __init__(self, in_channels: int, out_channels: int, normalize: bool = True, dropout: float = 0.0):
        super().__init__()
        # No conv bias: it would be redundant in front of BatchNorm.
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
        optional_norm = [nn.BatchNorm2d(out_channels)] if normalize else []
        optional_drop = [nn.Dropout2d(dropout)] if dropout > 0 else []
        # Layer order (and thus the "model.<idx>" checkpoint keys) is conv, norm, act, dropout.
        self.model = nn.Sequential(conv, *optional_norm, nn.LeakyReLU(0.2, inplace=True), *optional_drop)

    def forward(self, x):
        """Run ``x`` through the stage."""
        return self.model(x)
class UNetUpBlock(nn.Module):
    """Decoder stage of the U-Net: upsample, normalise, then append the skip connection.

    Pipeline: ConvTranspose(4x4, stride 2) -> pad/crop to the skip tensor's size when
    the shapes differ -> BatchNorm -> ReLU -> [Dropout2d] -> ``torch.cat`` with
    ``skip_input`` along dim 1. The result therefore has
    ``out_channels + skip_input.size(1)`` channels.

    Args:
        in_channels: channels of the incoming (deeper) feature map.
        out_channels: channels produced by the transposed convolution, before concat.
        dropout: ``Dropout2d`` probability; no dropout layer is created when 0.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(dropout) if dropout > 0 else None

    def forward(self, x, skip_input):
        """Upsample ``x`` and concatenate ``skip_input`` on the channel axis."""
        upsampled = self.upconv(x)
        if upsampled.shape != skip_input.shape:
            # An odd-sized skip tensor leaves the upsampled map one pixel short;
            # F.pad adds zeros (negative amounts would crop) to line the two up.
            gap_h = skip_input.size(2) - upsampled.size(2)
            gap_w = skip_input.size(3) - upsampled.size(3)
            upsampled = F.pad(upsampled, [gap_w // 2, gap_w - gap_w // 2, gap_h // 2, gap_h - gap_h // 2])
        activated = self.relu(self.norm(upsampled))
        if self.dropout is not None:
            activated = self.dropout(activated)
        return torch.cat([activated, skip_input], dim=1)
class GeneratorUNet(nn.Module):
    """U-Net generator for Yandex -> Google translation.

    Layout:
    - Seven encoder stages (``down1``..``down7``) followed by a stride-2 ``bottleneck``.
      That is eight halvings in total, so height and width must be at least 256.
    - Seven decoder stages (``up1``..``up7``). Each one consumes the previous
      output and concatenates the matching encoder activation.
    - A final transposed convolution with ``Tanh``, so outputs lie in (-1, 1).

    For even input sizes the output has the same spatial size as the input.

    Args:
        in_channels: channels of the source image.
        out_channels: channels of the generated image.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3):
        super().__init__()
        # Encoder stages: (input channels, output channels, use BatchNorm).
        # The first stage works on raw pixels and has no BatchNorm.
        encoder_spec = [
            (in_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
        ]
        for stage, (c_in, c_out, use_norm) in enumerate(encoder_spec, start=1):
            setattr(self, f"down{stage}", UNetDownBlock(c_in, c_out, normalize=use_norm))

        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )

        # Decoder stages: (input channels, output channels before concat, dropout).
        # Each input width is the previous stage's output plus its concatenated skip.
        decoder_spec = [
            (512, 512, 0.5),   # bottleneck 512 -> 512, + d7 (512) = 1024
            (1024, 512, 0.5),  # -> 512, + d6 (512) = 1024
            (1024, 512, 0.5),  # -> 512, + d5 (512) = 1024
            (1024, 512, 0.0),  # -> 512, + d4 (512) = 1024
            (1024, 256, 0.0),  # -> 256, + d3 (256) = 512
            (512, 128, 0.0),   # -> 128, + d2 (128) = 256
            (256, 64, 0.0),    # -> 64,  + d1 (64)  = 128
        ]
        for stage, (c_in, c_out, drop) in enumerate(decoder_spec, start=1):
            setattr(self, f"up{stage}", UNetUpBlock(c_in, c_out, dropout=drop))

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate ``x`` into the target domain."""
        encoders = (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6, self.down7)
        decoders = (self.up1, self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)

        skips = []
        h = x
        for encode in encoders:
            h = encode(h)
            skips.append(h)

        h = self.bottleneck(h)

        # The shallowest decoder pairs with the deepest encoder output (d7), and so on.
        for decode, skip in zip(decoders, reversed(skips)):
            h = decode(h, skip)
        return self.final(h)
class DiscriminatorPatchGAN(nn.Module):
    """PatchGAN discriminator conditioned on the source image.

    ``img_A`` and ``img_B`` are concatenated along the channel axis, so
    ``in_channels`` must equal the sum of their channel counts (3 + 3 by default).
    Four stride-2 convolutions and a final stride-1 convolution produce a
    single-channel map of per-patch scores rather than one scalar per image.

    The scores are raw logits by default. ``GANLoss`` in ``"vanilla"`` mode uses
    ``BCEWithLogitsLoss``, which applies its own sigmoid, so a sigmoid here would
    be applied twice. The least-squares and Wasserstein objectives are normally
    computed on raw scores as well.

    Args:
        in_channels: channels of the concatenated (A, B) pair.
        use_sigmoid: append ``nn.Sigmoid`` so the outputs are probabilities in
            (0, 1). Only use this when probabilities are needed directly, never
            together with a logits-based loss.
    """

    def __init__(self, in_channels: int = 6, use_sigmoid: bool = False):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
        ]
        if use_sigmoid:
            # Appended last and parameter-free, so the "model.<idx>" checkpoint keys
            # are identical with or without it.
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        """Score the (``img_A``, ``img_B``) pair patch-wise."""
        x = torch.cat([img_A, img_B], dim=1)
        return self.model(x)
class GANLoss(nn.Module):
    """Adversarial loss shared by the generator and the discriminator.

    Supported ``gan_mode`` values:

    - ``"vanilla"``: ``BCEWithLogitsLoss`` against a constant label map, so the
      predictions must be raw logits.
    - ``"lsgan"``: ``MSELoss`` against a constant label map.
    - ``"wgangp"``: no criterion. Returns ``-mean(prediction)`` for real targets and
      ``mean(prediction)`` for fake ones; any gradient penalty is left to the caller.

    Args:
        gan_mode: one of the modes above; anything else raises ``ValueError``.
        target_real: label value used when ``target_is_real`` is True.
        target_fake: label value used when ``target_is_real`` is False.
    """

    def __init__(self, gan_mode: str = "vanilla", target_real: float = 1.0, target_fake: float = 0.0):
        super().__init__()
        self.gan_mode = gan_mode
        # Buffers, so the labels follow the module across .to(device) calls.
        self.register_buffer("real_label", torch.tensor(target_real))
        self.register_buffer("fake_label", torch.tensor(target_fake))

        # Supported modes and the criterion class each one uses (None = no criterion).
        criterion_by_mode = (
            ("vanilla", nn.BCEWithLogitsLoss),
            ("lsgan", nn.MSELoss),
            ("wgangp", None),
        )
        for mode, criterion_cls in criterion_by_mode:
            if gan_mode == mode:
                self.loss_fn = None if criterion_cls is None else criterion_cls()
                break
        else:
            raise ValueError(f"Unknown GAN mode: {gan_mode}")

    def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
        """Return the loss of ``prediction`` against the real (True) or fake (False) target."""
        if self.gan_mode == "wgangp":
            # Minimising -mean(real) / +mean(fake) pushes real scores up and fake scores down.
            score = prediction.mean()
            return -score if target_is_real else score
        if self.gan_mode not in ("vanilla", "lsgan"):
            return None
        label = self.real_label if target_is_real else self.fake_label
        # Broadcast the scalar label over the whole prediction map without copying.
        return self.loss_fn(prediction, label.expand_as(prediction))
class ImageGAN(nn.Module):
    """Conditional image-translation GAN (Yandex -> Google).

    Bundles four pieces:
    - a ``GeneratorUNet``;
    - a ``DiscriminatorPatchGAN`` that scores (input, output) pairs;
    - the adversarial ``GANLoss``;
    - an L1 reconstruction loss weighted by ``lambda_L1``.

    The whole module is moved to ``self.device`` during construction.

    Args:
        input_channels: channels of the Yandex (source) image.
        output_channels: channels of the Google (target) image.
        gan_mode: adversarial objective forwarded to ``GANLoss``.
        lambda_L1: weight of the L1 term in the generator loss.
        use_cuda: use CUDA when it is available; otherwise fall back to CPU.

    NOTE(review): with ``gan_mode="wgangp"`` no gradient-penalty term is computed
    anywhere in this class. Confirm the training loop adds one.
    """

    def __init__(
        self,
        input_channels: int = 3,
        output_channels: int = 3,
        gan_mode: str = "vanilla",
        lambda_L1: float = 100.0,
        use_cuda: bool = True,
    ):
        super().__init__()
        self.generator = GeneratorUNet(input_channels, output_channels)
        # The discriminator sees source and candidate target concatenated on channels.
        self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
        self.gan_loss = GANLoss(gan_mode)
        self.l1_loss = nn.L1Loss()
        self.lambda_L1 = lambda_L1
        self.device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
        self.to(self.device)

    def forward(self, yandex_image):
        """Generate a Google-style image from a Yandex image."""
        return self.generator(yandex_image)

    def generator_step(self, yandex_img, real_google_img, fake_google_img=None):
        """Compute the generator losses.

        Args:
            yandex_img: source images fed to the generator.
            real_google_img: ground-truth target images for the L1 term.
            fake_google_img: optional tensor already produced by
                ``self.generator(yandex_img)`` in the current iteration, for
                example the tensor just passed to :meth:`discriminator_step`.
                Reusing it avoids a second generator forward pass. Because the
                generator contains dropout, it also guarantees both steps see the
                same sample. It must not be detached, or no gradient reaches the
                generator. When ``None`` (the default) the generator is run here.

        Returns:
            ``(total_loss, gan_loss, l1_loss)``. ``l1_loss`` is already scaled by
            ``lambda_L1``, and ``total_loss = gan_loss + l1_loss``.
        """
        if fake_google_img is None:
            fake_google_img = self.generator(yandex_img)
        # The generator is rewarded when the discriminator labels its output as real.
        fake_pred = self.discriminator(yandex_img, fake_google_img)
        gan_loss = self.gan_loss(fake_pred, True)
        l1_loss = self.l1_loss(fake_google_img, real_google_img) * self.lambda_L1
        total_loss = gan_loss + l1_loss
        return total_loss, gan_loss, l1_loss

    def discriminator_step(self, yandex_img, real_google_img, fake_google_img):
        """Compute the discriminator losses.

        ``fake_google_img`` is detached, so this loss never back-propagates into
        the generator.

        Returns:
            ``(total_loss, real_loss, fake_loss)`` where ``total_loss`` is the mean
            of the real and fake terms.
        """
        real_pred = self.discriminator(yandex_img, real_google_img)
        real_loss = self.gan_loss(real_pred, True)
        fake_pred = self.discriminator(yandex_img, fake_google_img.detach())
        fake_loss = self.gan_loss(fake_pred, False)
        total_loss = (real_loss + fake_loss) * 0.5
        return total_loss, real_loss, fake_loss
def create_gan(
    input_channels: int = 3,
    output_channels: int = 3,
    gan_mode: str = "vanilla",
    lambda_L1: float = 100.0,
    use_cuda: bool = True,
) -> ImageGAN:
    """Build an ``ImageGAN``; every argument is forwarded unchanged by keyword.

    See ``ImageGAN`` for the meaning of each setting.
    """
    settings = {
        "input_channels": input_channels,
        "output_channels": output_channels,
        "gan_mode": gan_mode,
        "lambda_L1": lambda_L1,
        "use_cuda": use_cuda,
    }
    return ImageGAN(**settings)
def initialize_weights(model: nn.Module, std: float = 0.02) -> None:
    """Re-initialise ``model`` and all of its submodules in place.

    - ``Conv2d`` / ``ConvTranspose2d`` / ``Linear``: weights ~ N(0, std), biases set to 0.
    - ``BatchNorm2d``: scale ~ N(1, std), shift set to 0.

    Args:
        model: module to initialise; ``model.modules()`` includes the module itself.
        std: standard deviation of the normal draws (0.02 by default).
    """
    for m in model.modules():
        # ConvTranspose2d is not a subclass of Conv2d, so it must be listed explicitly.
        # Otherwise every decoder up-convolution keeps PyTorch's default init.
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            # nn.init functions run under no_grad, so the Parameter is passed directly.
            nn.init.normal_(m.weight, 0.0, std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.BatchNorm2d):
            # affine=False BatchNorm has no weight/bias to initialise.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
if __name__ == "__main__":
    # Quick smoke test: build the model on CPU, run one generator forward pass
    # and report parameter counts.
    model = create_gan(use_cuda=False)
    print(f"Model created on {model.device}")

    # Test forward pass (a shape check needs no autograd graph)
    test_input = torch.randn(2, 3, 256, 256).to(model.device)
    with torch.no_grad():
        output = model(test_input)
    print(f"Output shape: {output.shape}")

    # Count parameters
    gen_params = sum(p.numel() for p in model.generator.parameters())
    disc_params = sum(p.numel() for p in model.discriminator.parameters())
    print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params")