from __future__ import annotations import importlib.util import os from pathlib import Path from typing import Optional import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from vision_chunk import VisionChunk ROOT_DIR = Path(__file__).resolve().parent MODEL_FILE = ROOT_DIR / "models" / "GAN" / "src" / "model.py" DEFAULT_CHECKPOINT_PATH = ROOT_DIR / "models" / "TrainedWeights" / "GAN.pth" IMAGE_SIZE = (256, 256) CHECKPOINT_ENV = "GAN_CHECKPOINT" _generator: Optional[torch.nn.Module] = None _device: Optional[torch.device] = None _translated_chunks: dict[int, VisionChunk] = {} class _LegacyUNetUpBlock(nn.Module): """Upsampling block used by earlier GAN checkpoints in models/GAN/runs.""" def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0): super().__init__() layers = [ nn.ConvTranspose2d( in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False, ), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), ] if dropout > 0: layers.append(nn.Dropout2d(dropout)) self.model = nn.Sequential(*layers) def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor: x = self.model(x) if x.shape != skip_input.shape: diff_h = skip_input.size(2) - x.size(2) diff_w = skip_input.size(3) - x.size(3) x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]) return torch.cat([x, skip_input], dim=1) class _LegacyGeneratorUNet(nn.Module): """Generator architecture matching old ConvTranspose2d checkpoints.""" def __init__(self, down_block_cls, in_channels: int = 3, out_channels: int = 3): super().__init__() self.down1 = down_block_cls(in_channels, 64, normalize=False) self.down2 = down_block_cls(64, 128) self.down3 = down_block_cls(128, 256) self.down4 = down_block_cls(256, 512) self.down5 = down_block_cls(512, 512) self.down6 = down_block_cls(512, 512) self.down7 = down_block_cls(512, 512) self.bottleneck = nn.Sequential( nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False), nn.ReLU(inplace=True), ) self.up1 = _LegacyUNetUpBlock(512, 512, dropout=0.5) self.up2 = _LegacyUNetUpBlock(1024, 512, dropout=0.5) self.up3 = _LegacyUNetUpBlock(512, 512, dropout=0.5) self.up4 = _LegacyUNetUpBlock(1024, 512) self.up5 = _LegacyUNetUpBlock(1024, 256) self.up6 = _LegacyUNetUpBlock(512, 128) self.up7 = _LegacyUNetUpBlock(256, 64) self.final = nn.Sequential( nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1), nn.Tanh(), ) def forward(self, x: torch.Tensor) -> torch.Tensor: d1 = self.down1(x) d2 = self.down2(d1) d3 = self.down3(d2) d4 = self.down4(d3) d5 = self.down5(d4) u = self.bottleneck(d5) u = self.up3(u, d5) u = self.up4(u, d4) u = self.up5(u, d3) u = self.up6(u, d2) u = self.up7(u, d1) return self.final(u) class _NamedTransposeUNetUpBlock(nn.Module): """ConvTranspose block with parameter names used by best.pth.""" def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0): super().__init__() self.upconv = nn.ConvTranspose2d( in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False, ) self.norm = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) self.dropout = nn.Dropout2d(dropout) if dropout > 0 else None def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor: x = self.upconv(x) if x.shape != skip_input.shape: diff_h = skip_input.size(2) - x.size(2) diff_w = skip_input.size(3) - x.size(3) x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]) x = self.norm(x) x = self.relu(x) if self.dropout is not None: x = self.dropout(x) return torch.cat([x, skip_input], dim=1) class _NamedTransposeGeneratorUNet(nn.Module): """Full U-Net architecture matching checkpoints with upN.upconv.weight.""" def __init__(self, down_block_cls, in_channels: int = 3, out_channels: int = 3): super().__init__() self.down1 = down_block_cls(in_channels, 64, normalize=False) self.down2 = down_block_cls(64, 128) self.down3 = down_block_cls(128, 256) self.down4 = down_block_cls(256, 512) self.down5 = down_block_cls(512, 512) self.down6 = down_block_cls(512, 512) self.down7 = down_block_cls(512, 512) self.bottleneck = nn.Sequential( nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False), nn.ReLU(inplace=True), ) self.up1 = _NamedTransposeUNetUpBlock(512, 512, dropout=0.5) self.up2 = _NamedTransposeUNetUpBlock(1024, 512, dropout=0.5) self.up3 = _NamedTransposeUNetUpBlock(1024, 512, dropout=0.5) self.up4 = _NamedTransposeUNetUpBlock(1024, 512) self.up5 = _NamedTransposeUNetUpBlock(1024, 256) self.up6 = _NamedTransposeUNetUpBlock(512, 128) self.up7 = _NamedTransposeUNetUpBlock(256, 64) self.final = nn.Sequential( nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1), nn.Tanh(), ) def forward(self, x: torch.Tensor) -> torch.Tensor: d1 = self.down1(x) d2 = self.down2(d1) d3 = self.down3(d2) d4 = self.down4(d3) d5 = self.down5(d4) d6 = self.down6(d5) d7 = self.down7(d6) u = self.bottleneck(d7) u = self.up1(u, d7) u = self.up2(u, d6) u = self.up3(u, d5) u = self.up4(u, d4) u = self.up5(u, d3) u = self.up6(u, d2) u = self.up7(u, d1) return self.final(u) def _load_gan_module(): spec = importlib.util.spec_from_file_location("gan_model", MODEL_FILE) if spec is None or spec.loader is None: raise ImportError(f"Cannot load GAN model from {MODEL_FILE}") module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module def _get_checkpoint_path() -> Path: checkpoint_path = os.getenv(CHECKPOINT_ENV) if checkpoint_path: return Path(checkpoint_path).expanduser().resolve() return DEFAULT_CHECKPOINT_PATH def _get_device() -> torch.device: global _device if _device is None: gan_module = _load_gan_module() if hasattr(gan_module, "get_compatible_device"): _device = gan_module.get_compatible_device(prefer_cuda=True) else: _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") return _device def _extract_generator_state_dict(checkpoint) -> dict: if not isinstance(checkpoint, dict): return checkpoint if "generator" in checkpoint: return checkpoint["generator"] if "generator_state_dict" in checkpoint: return checkpoint["generator_state_dict"] state_dict = checkpoint.get("model_state_dict", checkpoint) if any(key.startswith("generator.") for key in state_dict): return { key.removeprefix("generator."): value for key, value in state_dict.items() if key.startswith("generator.") } return state_dict def _get_generator() -> torch.nn.Module: global _generator if _generator is not None: return _generator checkpoint_path = _get_checkpoint_path() if not checkpoint_path.exists(): raise FileNotFoundError( f"GAN checkpoint not found: {checkpoint_path}. " f"Set {CHECKPOINT_ENV} to another .pth file if needed." ) gan_module = _load_gan_module() device = _get_device() checkpoint = torch.load(checkpoint_path, map_location=device) state_dict = _extract_generator_state_dict(checkpoint) if any(key.endswith(".upconv.weight") for key in state_dict): generator = _NamedTransposeGeneratorUNet(gan_module.UNetDownBlock, in_channels=3, out_channels=3).to(device) elif "final.0.weight" in state_dict: generator = _LegacyGeneratorUNet(gan_module.UNetDownBlock, in_channels=3, out_channels=3).to(device) else: generator = gan_module.GeneratorUNet(in_channels=3, out_channels=3).to(device) generator.load_state_dict(state_dict) generator.eval() _generator = generator return _generator def _chunk_to_tensor(chunk: VisionChunk) -> torch.Tensor: image = chunk.image.convert("RGB").resize(IMAGE_SIZE, Image.BILINEAR) array = np.asarray(image, dtype=np.float32) / 127.5 - 1.0 tensor = torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0) return tensor.to(_get_device()) def _tensor_to_image(tensor: torch.Tensor, size: tuple[int, int]) -> Image.Image: array = tensor.squeeze(0).detach().cpu().permute(1, 2, 0).numpy() array = ((array + 1.0) * 127.5).clip(0, 255).astype(np.uint8) image = Image.fromarray(array, mode="RGB") if image.size != size: image = image.resize(size, Image.BILINEAR) return image def transform_image(image: Image.Image) -> Image.Image: """Translate a Google-style reference image into the trained GAN target style.""" generator = _get_generator() source = VisionChunk(image=image) tensor = _chunk_to_tensor(source) with torch.inference_mode(): translated = generator(tensor) return _tensor_to_image(translated, image.size) def transform_chunk(chunk: VisionChunk, force: bool = False) -> VisionChunk: """Return a cached GAN-transformed copy of the reference chunk.""" if chunk is None: return chunk cache_key = id(chunk) if not force and cache_key in _translated_chunks: return _translated_chunks[cache_key] translated = VisionChunk( image=transform_image(chunk.image), feature_method=chunk.feature_method, ) translated.pos = chunk.pos _translated_chunks[cache_key] = translated return translated def clear_cache() -> None: _translated_chunks.clear()