146 lines
3.8 KiB
Python
146 lines
3.8 KiB
Python
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Optional, Sequence
|
|
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from vision_chunk import VisionChunk
|
|
|
|
|
|
# Directory that contains this module; all bundled model paths hang off it.
ROOT_DIR = Path(__file__).resolve().parent

# Source file that defines the ``SimilarityCNN`` class (imported by file path).
MODEL_FILE = ROOT_DIR.joinpath("models", "SiaN-similarity", "model.py")

# Checkpoint used when the CHECKPOINT_ENV environment variable is not set.
DEFAULT_CHECKPOINT_PATH = ROOT_DIR.joinpath(
    "models",
    "SiaN-similarity",
    "runs",
    "gan_training",
    "checkpoints",
    "best_model.pt",
)

# Size handed to PIL's ``Image.resize`` (width, height) before inference.
IMAGE_SIZE = (256, 256)
# Score at or above which two chunks count as similar by default.
DEFAULT_THRESHOLD = 0.5
# Names of the environment variables that override the defaults above.
CHECKPOINT_ENV = "SIAN_SIMILARITY_CHECKPOINT"
THRESHOLD_ENV = "SIAN_SIMILARITY_THRESHOLD"

# Lazily populated module-level caches (see _get_model / _get_device).
_model: Optional[torch.nn.Module] = None
_device: Optional[torch.device] = None
|
|
|
|
|
|
def _load_similarity_class():
    """Execute ``MODEL_FILE`` as a module and return its ``SimilarityCNN`` class.

    The file is imported by path under the module name
    ``"sian_similarity_model"``; it is executed afresh on every call and is
    not registered in ``sys.modules`` by this function.

    Raises:
        ImportError: If no import spec, or no loader, can be built for
            ``MODEL_FILE``.
    """
    module_spec = importlib.util.spec_from_file_location("sian_similarity_model", MODEL_FILE)
    loader = None if module_spec is None else module_spec.loader
    if loader is None:
        raise ImportError(f"Cannot load similarity model from {MODEL_FILE}")

    model_module = importlib.util.module_from_spec(module_spec)
    loader.exec_module(model_module)
    return model_module.SimilarityCNN
|
|
|
|
|
|
def _get_checkpoint_path() -> Path:
    """Return the checkpoint file the model should be loaded from.

    A non-empty ``CHECKPOINT_ENV`` environment variable wins and is returned
    with ``~`` expanded and resolved to an absolute path. Otherwise
    ``DEFAULT_CHECKPOINT_PATH`` is returned as-is.
    """
    override = os.environ.get(CHECKPOINT_ENV)
    if not override:
        # Unset or empty variable: fall back to the bundled checkpoint.
        return DEFAULT_CHECKPOINT_PATH
    return Path(override).expanduser().resolve()
|
|
|
|
|
|
def get_threshold() -> float:
    """Return the similarity threshold configured for this process.

    The ``THRESHOLD_ENV`` environment variable is read on every call. When it
    is unset, empty, or contains only whitespace, ``DEFAULT_THRESHOLD`` is
    returned; this mirrors how an empty ``CHECKPOINT_ENV`` is treated as
    "not set" by ``_get_checkpoint_path``.

    Returns:
        The threshold parsed with ``float``.

    Raises:
        ValueError: If the variable holds text that ``float`` cannot parse.
            The message names the offending variable and value.
    """
    raw = os.getenv(THRESHOLD_ENV)
    # An exported-but-blank variable (e.g. `VAR=` in a shell or compose file)
    # means "no override", not a malformed number.
    if raw is None or not raw.strip():
        return DEFAULT_THRESHOLD

    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(
            f"Invalid {THRESHOLD_ENV} value {raw!r}: expected a number."
        ) from err
|
|
|
|
|
|
def _get_device() -> torch.device:
    """Return the torch device used for inference, selecting it on first use.

    The choice is ``"cuda"`` when ``torch.cuda.is_available()`` is true and
    ``"cpu"`` otherwise. It is stored in the module-level ``_device`` so later
    calls return the same object.
    """
    global _device

    if _device is not None:
        return _device

    backend = "cuda" if torch.cuda.is_available() else "cpu"
    _device = torch.device(backend)
    return _device
|
|
|
|
|
|
def _get_model() -> torch.nn.Module:
    """Return the similarity network, building and caching it on first call.

    On the first call the checkpoint path is resolved, a ``SimilarityCNN`` is
    constructed with fixed hyper-parameters (ResNet-18 backbone, 3 input
    channels, no pretrained weights), moved to the inference device, loaded
    with the checkpoint's weights and switched to eval mode. The result is
    stored in the module-level ``_model`` and reused afterwards.

    Raises:
        FileNotFoundError: If the resolved checkpoint file does not exist.
    """
    global _model

    if _model is None:
        weights_file = _get_checkpoint_path()
        if not weights_file.exists():
            raise FileNotFoundError(
                f"SiaN similarity checkpoint not found: {weights_file}. "
                f"Set {CHECKPOINT_ENV} to another .pt file if needed."
            )

        model_cls = _load_similarity_class()
        target = _get_device()

        architecture = {
            "input_channels": 3,
            "backbone_name": "resnet18",
            "pretrained": False,
            "dropout_rate": 0.3,
            "use_batch_norm": True,
        }
        network = model_cls(**architecture).to(target)

        # NOTE(review): torch.load unpickles the file, so CHECKPOINT_ENV must
        # only ever point at trusted checkpoints. Consider weights_only=True
        # if the checkpoint is known to hold plain tensors/containers.
        payload = torch.load(weights_file, map_location=target)
        # Accept both a full training checkpoint (weights stored under
        # "model_state_dict") and a bare state dict.
        weights = payload.get("model_state_dict", payload)
        network.load_state_dict(weights)
        network.eval()

        # Publish only after the model is fully initialised.
        _model = network

    return _model
|
|
|
|
|
|
def _chunk_to_tensor(chunk: VisionChunk) -> torch.Tensor:
    """Turn a chunk's image into a single-item float32 batch on the model device.

    The image is converted to RGB, resized to ``IMAGE_SIZE`` with bilinear
    resampling, scaled by 1/255, reordered from channels-last to
    channels-first, and given a leading batch dimension of 1.

    Assumes ``chunk.image`` is a PIL image (it must provide ``convert`` and
    ``resize``) - TODO confirm against ``VisionChunk``.
    """
    rgb = chunk.image.convert("RGB")
    resized = rgb.resize(IMAGE_SIZE, Image.BILINEAR)
    pixels = np.asarray(resized, dtype=np.float32) / 255.0  # (H, W, 3)
    channels_first = torch.from_numpy(pixels).permute(2, 0, 1)  # (3, H, W)
    batch = channels_first.unsqueeze(0)  # (1, 3, H, W)
    return batch.to(_get_device())
|
|
|
|
|
|
def get_similarity_score(chunk1: VisionChunk, chunk2: VisionChunk) -> float:
    """Score how similar two chunks are according to the similarity model.

    Args:
        chunk1: First chunk to compare; ``None`` is tolerated.
        chunk2: Second chunk to compare; ``None`` is tolerated.

    Returns:
        The model's output for the pair as a Python float, or ``0.0`` when
        either chunk is ``None`` (the model is not loaded in that case).
    """
    if chunk1 is None or chunk2 is None:
        return 0.0

    network = _get_model()
    # Convert in argument order; each tensor is a batch of one image.
    left, right = (_chunk_to_tensor(item) for item in (chunk1, chunk2))

    with torch.inference_mode():
        raw_output = network(left, right)

    # The pair produces a single value; drop size-1 dims before extracting it.
    scalar = raw_output.squeeze()
    return float(scalar.item())
|
|
|
|
|
|
def get_similarity_scores(chunk: VisionChunk, candidates: Sequence[VisionChunk]) -> list[float]:
    """Score one chunk against many candidates in a single forward pass.

    Args:
        chunk: The reference chunk; ``None`` yields an empty result.
        candidates: Chunks to compare against ``chunk``; an empty sequence
            yields an empty result.

    Returns:
        One float per candidate, in the same order as ``candidates``, or
        ``[]`` when there is nothing to compare.
    """
    if chunk is None or not candidates:
        return []

    network = _get_model()
    reference = _chunk_to_tensor(chunk)
    stacked = torch.cat([_chunk_to_tensor(item) for item in candidates], dim=0)
    # expand() broadcasts the single reference image across the batch as a
    # view, so it is paired with every candidate without copying pixel data.
    batch_size = stacked.shape[0]
    tiled_reference = reference.expand(batch_size, -1, -1, -1)

    with torch.inference_mode():
        raw_output = network(tiled_reference, stacked)

    # Assumes the model emits one column per pair, i.e. shape (N, 1) -
    # TODO confirm against SimilarityCNN.
    per_candidate = raw_output.squeeze(1).detach().cpu().tolist()
    return list(map(float, per_candidate))
|
|
|
|
|
|
def is_similar(
    chunk1: VisionChunk,
    chunk2: VisionChunk,
    threshold: Optional[float] = None,
) -> bool:
    """Report whether two chunks reach the similarity threshold.

    Args:
        chunk1: First chunk to compare.
        chunk2: Second chunk to compare.
        threshold: Minimum score that counts as similar. When ``None``, the
            value from ``get_threshold()`` is used.

    Returns:
        ``True`` when the pair's similarity score is greater than or equal to
        the threshold.
    """
    cutoff = get_threshold() if threshold is None else threshold
    score = get_similarity_score(chunk1, chunk2)
    return score >= cutoff
|