Files
autopilot/sian_similarity.py
2026-05-30 14:49:40 +03:00

146 lines
3.8 KiB
Python

from __future__ import annotations
import importlib.util
import os
from pathlib import Path
from typing import Optional, Sequence
import numpy as np
import torch
from PIL import Image
from vision_chunk import VisionChunk
# Directory that contains this module; every bundled asset is located
# relative to it.
ROOT_DIR = Path(__file__).resolve().parent

# Python source file that defines the SimilarityCNN class (loaded by path).
MODEL_FILE = ROOT_DIR.joinpath("models", "SiaN-similarity", "model.py")

# Checkpoint used when the CHECKPOINT_ENV variable is not set.
DEFAULT_CHECKPOINT_PATH = ROOT_DIR.joinpath(
    "models",
    "SiaN-similarity",
    "runs",
    "gan_training",
    "checkpoints",
    "best_model.pt",
)

# Size every chunk image is resized to before it is fed to the model.
IMAGE_SIZE = (256, 256)

# Score at or above which two chunks are reported as similar.
DEFAULT_THRESHOLD = 0.5

# Names of the environment variables that override the defaults above.
CHECKPOINT_ENV = "SIAN_SIMILARITY_CHECKPOINT"
THRESHOLD_ENV = "SIAN_SIMILARITY_THRESHOLD"

# Lazily created, process-wide caches for the loaded model and its device.
_model: Optional[torch.nn.Module] = None
_device: Optional[torch.device] = None
def _load_similarity_class():
    """Import the model source file by path and return its ``SimilarityCNN``.

    The file at ``MODEL_FILE`` is executed as a module named
    ``sian_similarity_model`` each time this function is called; the module
    is not registered in ``sys.modules``.

    Raises:
        ImportError: If no import spec (or no loader) can be built for
            ``MODEL_FILE``.
    """
    module_spec = importlib.util.spec_from_file_location(
        "sian_similarity_model", MODEL_FILE
    )
    if module_spec is not None and module_spec.loader is not None:
        loaded = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(loaded)
        return loaded.SimilarityCNN
    raise ImportError(f"Cannot load similarity model from {MODEL_FILE}")
def _get_checkpoint_path() -> Path:
    """Return the checkpoint file to load.

    A non-empty value of the ``CHECKPOINT_ENV`` environment variable wins
    (with ``~`` expanded and the path resolved); otherwise
    ``DEFAULT_CHECKPOINT_PATH`` is returned unchanged.
    """
    override = os.getenv(CHECKPOINT_ENV)
    if not override:
        # Unset and empty are treated alike: use the bundled checkpoint.
        return DEFAULT_CHECKPOINT_PATH
    return Path(override).expanduser().resolve()
def get_threshold() -> float:
    """Return the similarity decision threshold.

    The value is read from the environment variable named by
    ``THRESHOLD_ENV``. When the variable is unset, empty or contains only
    whitespace, ``DEFAULT_THRESHOLD`` is returned, which matches how the
    checkpoint override ignores an empty value.

    Raises:
        ValueError: If the variable is set to text that cannot be parsed as
            a float. The message names the offending variable and value.
    """
    raw = os.getenv(THRESHOLD_ENV)
    if raw is None or not raw.strip():
        return DEFAULT_THRESHOLD
    try:
        return float(raw)
    except ValueError as err:
        # Re-raise with context: the bare float() message does not say which
        # setting was malformed.
        raise ValueError(
            f"{THRESHOLD_ENV} must be a number, got {raw!r}"
        ) from err
def _get_device() -> torch.device:
    """Return the torch device used for inference, choosing it on first use.

    CUDA is selected when ``torch.cuda.is_available()`` reports it, CPU
    otherwise. The choice is cached in the module-level ``_device``.
    """
    global _device
    if _device is not None:
        return _device
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    _device = torch.device(backend)
    return _device
def _get_model() -> torch.nn.Module:
    """Return the similarity network, building and caching it on first use.

    On the first call the checkpoint is located, the ``SimilarityCNN`` class
    is loaded, an instance is created on the inference device, the weights
    are restored and the network is switched to eval mode. Later calls return
    the cached instance stored in the module-level ``_model``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    global _model
    if _model is None:
        weights_file = _get_checkpoint_path()
        if not weights_file.exists():
            raise FileNotFoundError(
                f"SiaN similarity checkpoint not found: {weights_file}. "
                f"Set {CHECKPOINT_ENV} to another .pt file if needed."
            )

        network_cls = _load_similarity_class()
        target = _get_device()
        network_options = {
            "input_channels": 3,
            "backbone_name": "resnet18",
            "pretrained": False,  # weights come from the checkpoint below
            "dropout_rate": 0.3,
            "use_batch_norm": True,
        }
        network = network_cls(**network_options).to(target)

        # NOTE(review): torch.load unpickles the file, and the path may come
        # from an environment variable -- only point it at trusted
        # checkpoints (consider weights_only=True if the file allows it).
        payload = torch.load(weights_file, map_location=target)
        # Use the "model_state_dict" entry when present; otherwise treat the
        # loaded object itself as the state dict.
        weights = payload.get("model_state_dict", payload)
        network.load_state_dict(weights)
        network.eval()
        _model = network
    return _model
def _chunk_to_tensor(chunk: VisionChunk) -> torch.Tensor:
    """Convert a chunk's image into a model-ready float tensor.

    The image is converted to RGB, resized to ``IMAGE_SIZE`` with bilinear
    resampling, divided by 255 and rearranged from height-width-channel to
    channel-height-width order with a leading batch axis of size 1. The
    result is moved to the inference device.
    """
    rgb = chunk.image.convert("RGB")
    resized = rgb.resize(IMAGE_SIZE, Image.BILINEAR)
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    # HWC -> CHW, then add the batch dimension in front.
    channels_first = torch.from_numpy(pixels).permute(2, 0, 1)
    batched = channels_first.unsqueeze(0)
    return batched.to(_get_device())
def get_similarity_score(chunk1: VisionChunk, chunk2: VisionChunk) -> float:
    """Return the model's similarity score for a pair of chunks.

    If either chunk is ``None`` the score is ``0.0`` and the model is never
    loaded. Otherwise both images are converted to tensors and passed
    through the network under ``torch.inference_mode``.
    """
    if chunk1 is not None and chunk2 is not None:
        network = _get_model()
        first = _chunk_to_tensor(chunk1)
        second = _chunk_to_tensor(chunk2)
        with torch.inference_mode():
            raw = network(first, second)
        # The output holds a single element; collapse it to a Python float.
        return float(raw.squeeze().item())
    return 0.0
def get_similarity_scores(chunk: VisionChunk, candidates: Sequence[VisionChunk]) -> list[float]:
    """Score one chunk against many candidates in a single batched pass.

    Args:
        chunk: The reference chunk. ``None`` yields an empty result.
        candidates: Chunks to compare against. An empty sequence yields an
            empty result. A ``None`` entry scores ``0.0``, the same value
            ``get_similarity_score`` returns for a missing chunk.

    Returns:
        One score per candidate, in the same order as ``candidates``.

    Raises:
        ValueError: If the model does not return exactly one score per
            candidate that was passed to it.
    """
    if chunk is None or not candidates:
        return []

    scores = [0.0] * len(candidates)
    # Only real candidates go through the model; remember their positions so
    # the output stays aligned with the input sequence.
    present = [
        (index, candidate)
        for index, candidate in enumerate(candidates)
        if candidate is not None
    ]
    if not present:
        return scores

    model = _get_model()
    query = _chunk_to_tensor(chunk)
    batch = torch.cat([_chunk_to_tensor(candidate) for _, candidate in present], dim=0)
    # Broadcast the single query image across the candidate batch.
    repeated_query = query.expand(batch.shape[0], -1, -1, -1)
    with torch.inference_mode():
        similarities = model(repeated_query, batch)

    # Flatten so both (N, 1) and (N,) model outputs are accepted.
    flat = similarities.reshape(-1)
    if flat.numel() != len(present):
        raise ValueError(
            f"Expected {len(present)} similarity scores, "
            f"got output of shape {tuple(similarities.shape)}"
        )
    for (index, _), value in zip(present, flat.detach().cpu().tolist()):
        scores[index] = float(value)
    return scores
def is_similar(
    chunk1: VisionChunk,
    chunk2: VisionChunk,
    threshold: Optional[float] = None,
) -> bool:
    """Report whether two chunks score at or above the similarity threshold.

    Args:
        chunk1: First chunk to compare.
        chunk2: Second chunk to compare.
        threshold: Minimum score counted as similar. When ``None`` the value
            from ``get_threshold()`` is used.
    """
    cutoff = get_threshold() if threshold is None else threshold
    score = get_similarity_score(chunk1, chunk2)
    return score >= cutoff