feat: gan integration
This commit is contained in:
316
gan.py
Normal file
316
gan.py
Normal file
@@ -0,0 +1,316 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
|
||||
from vision_chunk import VisionChunk
|
||||
|
||||
|
||||
# Directory that contains this module; model assets are resolved relative to it.
ROOT_DIR = Path(__file__).resolve().parents[0]
# Python source file that defines the generator building blocks; it is loaded
# by path rather than through the regular import system.
MODEL_FILE = ROOT_DIR.joinpath("models", "GAN", "src", "model.py")
# Checkpoint used when the CHECKPOINT_ENV environment variable is unset/empty.
DEFAULT_CHECKPOINT_PATH = ROOT_DIR.joinpath(
    "models", "GAN", "runs", "checkpoints", "best.pth"
)

# Size every input image is resized to before it is fed to the generator.
IMAGE_SIZE = (256, 256)
# Name of the environment variable that overrides the checkpoint location.
CHECKPOINT_ENV = "GAN_CHECKPOINT"

# Lazily populated, process-wide state (filled in on first use).
_generator: Optional[torch.nn.Module] = None
_device: Optional[torch.device] = None
# Cache of already translated chunks, keyed by id() of the source chunk.
_translated_chunks: dict[int, VisionChunk] = {}
|
||||
|
||||
|
||||
class _LegacyUNetUpBlock(nn.Module):
|
||||
"""Upsampling block used by earlier GAN checkpoints in models/GAN/runs."""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
layers = [
|
||||
nn.ConvTranspose2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
if dropout > 0:
|
||||
layers.append(nn.Dropout2d(dropout))
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor:
|
||||
x = self.model(x)
|
||||
if x.shape != skip_input.shape:
|
||||
diff_h = skip_input.size(2) - x.size(2)
|
||||
diff_w = skip_input.size(3) - x.size(3)
|
||||
x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
|
||||
return torch.cat([x, skip_input], dim=1)
|
||||
|
||||
|
||||
class _LegacyGeneratorUNet(nn.Module):
    """U-Net generator laid out like the old ConvTranspose2d checkpoints.

    ``down_block_cls`` is the encoder block class; it is called as
    ``down_block_cls(in_ch, out_ch)`` (plus ``normalize=False`` for the first
    stage).

    NOTE(review): ``down6``, ``down7``, ``up1`` and ``up2`` are constructed
    but never called in ``forward``. Presumably they exist so the module's
    parameter names line up with the legacy checkpoint - confirm before
    removing them.
    """

    def __init__(self, down_block_cls, in_channels: int = 3, out_channels: int = 3):
        super().__init__()
        # Encoder stages.
        self.down1 = down_block_cls(in_channels, 64, normalize=False)
        self.down2 = down_block_cls(64, 128)
        self.down3 = down_block_cls(128, 256)
        self.down4 = down_block_cls(256, 512)
        self.down5 = down_block_cls(512, 512)
        self.down6 = down_block_cls(512, 512)
        self.down7 = down_block_cls(512, 512)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )

        # Decoder stages. Each block concatenates its skip input, so the next
        # block's input width is the previous output width plus the skip width.
        self.up1 = _LegacyUNetUpBlock(in_channels=512, out_channels=512, dropout=0.5)
        self.up2 = _LegacyUNetUpBlock(in_channels=1024, out_channels=512, dropout=0.5)
        self.up3 = _LegacyUNetUpBlock(in_channels=512, out_channels=512, dropout=0.5)
        self.up4 = _LegacyUNetUpBlock(in_channels=1024, out_channels=512)
        self.up5 = _LegacyUNetUpBlock(in_channels=1024, out_channels=256)
        self.up6 = _LegacyUNetUpBlock(in_channels=512, out_channels=128)
        self.up7 = _LegacyUNetUpBlock(in_channels=256, out_channels=64)

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``down1..down5``, the bottleneck, ``up3..up7`` and ``final``."""
        skips = []
        features = x
        for encoder in (self.down1, self.down2, self.down3, self.down4, self.down5):
            features = encoder(features)
            skips.append(features)

        decoded = self.bottleneck(features)
        decoders = (self.up3, self.up4, self.up5, self.up6, self.up7)
        # Deepest skip connection is consumed first.
        for decoder, skip in zip(decoders, reversed(skips)):
            decoded = decoder(decoded, skip)
        return self.final(decoded)
|
||||
|
||||
|
||||
class _NamedTransposeUNetUpBlock(nn.Module):
|
||||
"""ConvTranspose block with parameter names used by best.pth."""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
self.upconv = nn.ConvTranspose2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False,
|
||||
)
|
||||
self.norm = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else None
|
||||
|
||||
def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor:
|
||||
x = self.upconv(x)
|
||||
if x.shape != skip_input.shape:
|
||||
diff_h = skip_input.size(2) - x.size(2)
|
||||
diff_w = skip_input.size(3) - x.size(3)
|
||||
x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
|
||||
x = self.norm(x)
|
||||
x = self.relu(x)
|
||||
if self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
return torch.cat([x, skip_input], dim=1)
|
||||
|
||||
|
||||
class _NamedTransposeGeneratorUNet(nn.Module):
    """Seven-level U-Net for checkpoints whose keys look like ``upN.upconv.weight``.

    ``down_block_cls`` is the encoder block class; it is called as
    ``down_block_cls(in_ch, out_ch)`` (plus ``normalize=False`` for the first
    stage).
    """

    def __init__(self, down_block_cls, in_channels: int = 3, out_channels: int = 3):
        super().__init__()
        # Encoder stages.
        self.down1 = down_block_cls(in_channels, 64, normalize=False)
        self.down2 = down_block_cls(64, 128)
        self.down3 = down_block_cls(128, 256)
        self.down4 = down_block_cls(256, 512)
        self.down5 = down_block_cls(512, 512)
        self.down6 = down_block_cls(512, 512)
        self.down7 = down_block_cls(512, 512)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )

        # Decoder stages. Each block concatenates its skip input, so every
        # stage after the first receives (previous output + skip) channels.
        self.up1 = _NamedTransposeUNetUpBlock(in_channels=512, out_channels=512, dropout=0.5)
        self.up2 = _NamedTransposeUNetUpBlock(in_channels=1024, out_channels=512, dropout=0.5)
        self.up3 = _NamedTransposeUNetUpBlock(in_channels=1024, out_channels=512, dropout=0.5)
        self.up4 = _NamedTransposeUNetUpBlock(in_channels=1024, out_channels=512)
        self.up5 = _NamedTransposeUNetUpBlock(in_channels=1024, out_channels=256)
        self.up6 = _NamedTransposeUNetUpBlock(in_channels=512, out_channels=128)
        self.up7 = _NamedTransposeUNetUpBlock(in_channels=256, out_channels=64)

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run all seven encoder stages, the bottleneck, ``up1..up7`` and ``final``."""
        encoders = (
            self.down1,
            self.down2,
            self.down3,
            self.down4,
            self.down5,
            self.down6,
            self.down7,
        )
        skips = []
        features = x
        for encoder in encoders:
            features = encoder(features)
            skips.append(features)

        decoded = self.bottleneck(features)
        decoders = (
            self.up1,
            self.up2,
            self.up3,
            self.up4,
            self.up5,
            self.up6,
            self.up7,
        )
        # Deepest skip connection is consumed first.
        for decoder, skip in zip(decoders, reversed(skips)):
            decoded = decoder(decoded, skip)
        return self.final(decoded)
|
||||
|
||||
|
||||
def _load_gan_module():
    """Execute ``MODEL_FILE`` as a module named ``gan_model`` and return it.

    The module is not registered in ``sys.modules`` and nothing is cached, so
    every call runs the model file again.

    Raises:
        ImportError: if no import spec/loader can be created for the file.
    """
    spec = importlib.util.spec_from_file_location("gan_model", MODEL_FILE)
    loader = None if spec is None else spec.loader
    if loader is None:
        raise ImportError(f"Cannot load GAN model from {MODEL_FILE}")

    gan_model = importlib.util.module_from_spec(spec)
    loader.exec_module(gan_model)
    return gan_model
|
||||
|
||||
|
||||
def _get_checkpoint_path() -> Path:
    """Return the checkpoint file to load.

    A non-empty ``CHECKPOINT_ENV`` environment variable wins (with ``~``
    expanded and the path resolved); otherwise the bundled default is used.
    """
    override = os.environ.get(CHECKPOINT_ENV)
    if not override:
        return DEFAULT_CHECKPOINT_PATH
    return Path(override).expanduser().resolve()
|
||||
|
||||
|
||||
def _get_device() -> torch.device:
    """Return the inference device, resolving it once and caching it in ``_device``.

    If the loaded GAN module exposes ``get_compatible_device`` it is asked for
    a device with ``prefer_cuda=True``; otherwise CUDA is used when available
    and the CPU when not.
    """
    global _device

    if _device is not None:
        return _device

    gan_module = _load_gan_module()
    if not hasattr(gan_module, "get_compatible_device"):
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        _device = gan_module.get_compatible_device(prefer_cuda=True)
    return _device
|
||||
|
||||
|
||||
def _extract_generator_state_dict(checkpoint) -> dict:
|
||||
if not isinstance(checkpoint, dict):
|
||||
return checkpoint
|
||||
|
||||
if "generator" in checkpoint:
|
||||
return checkpoint["generator"]
|
||||
|
||||
if "generator_state_dict" in checkpoint:
|
||||
return checkpoint["generator_state_dict"]
|
||||
|
||||
state_dict = checkpoint.get("model_state_dict", checkpoint)
|
||||
if any(key.startswith("generator.") for key in state_dict):
|
||||
return {
|
||||
key.removeprefix("generator."): value
|
||||
for key, value in state_dict.items()
|
||||
if key.startswith("generator.")
|
||||
}
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def _get_generator() -> torch.nn.Module:
    """Return the generator, building and caching it in ``_generator`` on first use.

    The architecture is chosen from the checkpoint's parameter names:

    * any key ending in ``.upconv.weight`` -> ``_NamedTransposeGeneratorUNet``;
    * a ``final.0.weight`` key             -> ``_LegacyGeneratorUNet``;
    * otherwise                            -> ``GeneratorUNet`` from the GAN module.

    The network is moved to the inference device, loaded with the checkpoint
    weights and switched to eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    global _generator

    if _generator is None:
        checkpoint_path = _get_checkpoint_path()
        if not checkpoint_path.exists():
            raise FileNotFoundError(
                f"GAN checkpoint not found: {checkpoint_path}. "
                f"Set {CHECKPOINT_ENV} to another .pth file if needed."
            )

        gan_module = _load_gan_module()
        device = _get_device()
        # NOTE(security): torch.load unpickles the file, and the path can be
        # overridden through an environment variable. Only point it at trusted
        # checkpoints; consider passing weights_only=True if every supported
        # checkpoint contains plain tensors/containers.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        weights = _extract_generator_state_dict(checkpoint)

        if any(name.endswith(".upconv.weight") for name in weights):
            network = _NamedTransposeGeneratorUNet(
                gan_module.UNetDownBlock, in_channels=3, out_channels=3
            )
        elif "final.0.weight" in weights:
            network = _LegacyGeneratorUNet(
                gan_module.UNetDownBlock, in_channels=3, out_channels=3
            )
        else:
            network = gan_module.GeneratorUNet(in_channels=3, out_channels=3)

        network = network.to(device)
        network.load_state_dict(weights)
        network.eval()

        # Publish only after the weights loaded successfully.
        _generator = network

    return _generator
|
||||
|
||||
|
||||
def _chunk_to_tensor(chunk: VisionChunk) -> torch.Tensor:
    """Convert a chunk's image into a generator input tensor.

    The image is converted to RGB, resized to ``IMAGE_SIZE`` with bilinear
    resampling, rescaled from ``[0, 255]`` to ``[-1, 1]``, rearranged from
    height/width/channel to channel/height/width order, given a leading batch
    dimension of 1 and moved to the inference device.
    """
    rgb = chunk.image.convert("RGB")
    resized = rgb.resize(IMAGE_SIZE, Image.BILINEAR)
    pixels = np.asarray(resized, dtype=np.float32) / 127.5 - 1.0
    batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
    return batch.to(_get_device())
|
||||
|
||||
|
||||
def _tensor_to_image(tensor: torch.Tensor, size: tuple[int, int]) -> Image.Image:
    """Convert a generator output tensor back into an RGB PIL image.

    Args:
        tensor: Generator output; a leading batch dimension of 1 is squeezed
            away and the remaining dimensions are reordered from
            channel/height/width to height/width/channel. Values are mapped
            from ``[-1, 1]`` to ``[0, 255]``.
        size: Desired output size, compared against ``Image.size``; the image
            is resized with bilinear resampling when it differs.

    Returns:
        The image, at ``size``.
    """
    array = tensor.squeeze(0).detach().cpu().permute(1, 2, 0).numpy()
    # Round to the nearest integer before casting: astype(np.uint8) alone
    # truncates toward zero, which darkens the output by half a level on
    # average and turns values such as 127.99999 (float round-off of an exact
    # pixel value) into 127 instead of 128.
    array = np.rint((array + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    image = Image.fromarray(array, mode="RGB")
    if image.size != size:
        image = image.resize(size, Image.BILINEAR)
    return image
|
||||
|
||||
|
||||
def transform_image(image: Image.Image) -> Image.Image:
    """Run a Google-style reference image through the GAN generator.

    The image is wrapped in a temporary ``VisionChunk`` so it can share the
    chunk preprocessing path, and the generator output is converted back to an
    image of the same size as the input.
    """
    generator = _get_generator()
    batch = _chunk_to_tensor(VisionChunk(image=image))

    # inference_mode disables autograd tracking for the forward pass.
    with torch.inference_mode():
        output = generator(batch)

    return _tensor_to_image(output, image.size)
|
||||
|
||||
|
||||
def transform_chunk(chunk: VisionChunk, force: bool = False) -> VisionChunk:
    """Return a cached GAN-transformed copy of the reference chunk.

    Args:
        chunk: Source chunk; ``None`` is passed straight through.
        force: When true, ignore any cached result and translate again.

    Returns:
        A new ``VisionChunk`` holding the translated image, with the source's
        ``feature_method`` and ``pos`` carried over.
    """
    if chunk is None:
        return chunk

    cache_key = id(chunk)
    if not force:
        entry = _translated_chunks.get(cache_key)
        # id() values are only unique among objects that are alive at the same
        # time, so the same number can later belong to a different chunk.
        # A hit therefore counts only when the stored source is this very
        # object.
        if isinstance(entry, tuple) and entry[0] is chunk:
            return entry[1]

    translated = VisionChunk(
        image=transform_image(chunk.image),
        feature_method=chunk.feature_method,
    )
    translated.pos = chunk.pos
    # Store the source next to the result: this both enables the identity
    # check above and keeps the source alive while the entry exists, so its
    # id cannot be recycled. NOTE: cache values are now (source, translated)
    # pairs rather than bare chunks; the module-level annotation of
    # _translated_chunks should be widened accordingly.
    _translated_chunks[cache_key] = (chunk, translated)
    return translated
|
||||
|
||||
|
||||
def clear_cache() -> None:
    """Drop every cached translated chunk.

    The cache dictionary is emptied in place, so existing references to it
    stay valid. The loaded generator and the selected device are not reset.
    """
    _translated_chunks.clear()
|
||||
Reference in New Issue
Block a user