201 lines
7.0 KiB
Python
201 lines
7.0 KiB
Python
import cv2
|
||
import json
|
||
import numpy as np
|
||
from PIL import Image
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
from position import Position
|
||
from typing import Literal, Optional, Tuple
|
||
|
||
# Names of the feature detectors that VisionChunk knows how to build.
FeatureMethod = Literal["orb", "sift", "akaze", "brisk"]
# Detector used when the caller does not pick one explicitly. The annotation
# keeps type checkers from widening the value to plain ``str``.
DEFAULT_METHOD: FeatureMethod = "orb"
|
||
|
||
@dataclass
class VisionChunk:
    """A PIL image together with lazily computed, cached local features.

    Attributes:
        image: The source image.
        feature_method: Which detector/descriptor to use
            ("orb", "sift", "akaze" or "brisk").
        pos: Optional position; never assigned by this class itself.
        keypoints: Keypoints cached by ``compute_keypoints``.
        descriptors: Descriptors cached by ``compute_keypoints``.
    """

    image: Image.Image
    feature_method: FeatureMethod = DEFAULT_METHOD

    pos: Optional[Position] = field(default=None, init=False)
    keypoints: Optional[list] = field(default=None, init=False)
    descriptors: Optional[np.ndarray] = field(default=None, init=False)
    _detector: Optional[cv2.Feature2D] = field(default=None, init=False, repr=False)
    _matcher: Optional[cv2.DescriptorMatcher] = field(default=None, init=False, repr=False)

    # Tuning constants. They are deliberately unannotated so that the
    # dataclass machinery does not turn them into instance fields.
    _MAX_KEYPOINTS = 2500        # strongest keypoints kept per image
    _MIN_POINTS = 4              # minimum keypoints / matches required
    _LOWE_RATIO = 0.75           # Lowe ratio-test threshold
    _MAX_MATCH_DISTANCE = 64     # absolute cap on descriptor distance

    def _get_detector(self) -> cv2.Feature2D:
        """Return the detector for ``feature_method``, creating it on first use.

        Raises:
            ValueError: If ``feature_method`` is not one of the supported names.
        """
        if self._detector is not None:
            return self._detector

        if self.feature_method == "orb":
            self._detector = cv2.ORB_create(
                nfeatures=10000,
                scaleFactor=1.2,
                nlevels=32,
                edgeThreshold=31,
                firstLevel=0,
                WTA_K=2,
                patchSize=31,
                fastThreshold=20,
            )
        elif self.feature_method == "sift":
            self._detector = cv2.SIFT_create(
                nfeatures=1500,
                nOctaveLayers=2,
                contrastThreshold=0.01,
                edgeThreshold=15,
                sigma=3.3,
            )
        elif self.feature_method == "akaze":
            self._detector = cv2.AKAZE_create(
                descriptor_type=cv2.AKAZE_DESCRIPTOR_MLDB,
                descriptor_size=0,
                descriptor_channels=3,
                threshold=0.001,
                nOctaves=4,
                diffusivity=cv2.KAZE_DIFF_PM_G2,
            )
        elif self.feature_method == "brisk":
            self._detector = cv2.BRISK_create(
                thresh=70,
                octaves=7,
                patternScale=1.0,
            )
        else:
            raise ValueError(f"Unsupported feature method: {self.feature_method}")
        return self._detector

    def _get_matcher(self) -> cv2.DescriptorMatcher:
        """Return a cached brute-force matcher suited to the descriptor type.

        "sift" gets an L2 matcher; every other method gets a Hamming matcher.
        """
        if self._matcher is None:
            norm = cv2.NORM_L2 if self.feature_method == "sift" else cv2.NORM_HAMMING
            self._matcher = cv2.BFMatcher(norm, crossCheck=False)
        return self._matcher

    def _preprocess(self, img_np: np.ndarray) -> np.ndarray:
        """Convert to grayscale (if needed) and boost contrast with CLAHE.

        Args:
            img_np: A 2-D grayscale array, or a 3-D array in RGB channel order.

        Returns:
            The contrast-equalised single-channel image.
        """
        if img_np.ndim == 3:
            gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        else:
            gray = img_np

        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        return clahe.apply(gray)

    def compute_keypoints(self, force: bool = False) -> Tuple[list[cv2.KeyPoint], Optional[np.ndarray]]:
        """Detect keypoints and descriptors, keeping only the strongest ones.

        Results are cached on the instance and reused on later calls.

        Args:
            force: Recompute even when cached results exist.

        Returns:
            ``(keypoints, descriptors)``, ordered by descending detector
            response and limited to ``_MAX_KEYPOINTS`` entries. When nothing
            is detected the result is ``([], None)``.
        """
        if not force and self.keypoints is not None and self.descriptors is not None:
            return self.keypoints, self.descriptors

        detector = self._get_detector()

        # The PIL array is already in RGB order, which is what _preprocess
        # expects. Swapping to BGR first would exchange the red and blue
        # luminance weights and disagree with to_cv2_gray().
        preprocessed = self._preprocess(np.array(self.image))
        keypoints, descriptors = detector.detectAndCompute(preprocessed, None)

        # A featureless image yields no descriptors; report that explicitly
        # instead of failing while indexing None.
        if descriptors is None or len(keypoints) == 0:
            self.keypoints = []
            self.descriptors = None
            return self.keypoints, self.descriptors

        # Indices of the strongest keypoints, best first.
        responses = np.array([kp.response for kp in keypoints])
        top_indices = np.argsort(responses)[-self._MAX_KEYPOINTS:][::-1]

        self.keypoints = [keypoints[i] for i in top_indices]
        self.descriptors = descriptors[top_indices]
        return self.keypoints, self.descriptors

    def detect_and_match_keypoints(
        self,
        other: "VisionChunk"
    ) -> Tuple[
        Optional[np.ndarray],
        Optional[np.ndarray],
        Optional[list],
        Optional[list],
        Optional[list]
    ]:
        """Match this chunk's features against another chunk's.

        Candidates come from a k=2 nearest-neighbour search, are filtered by
        the Lowe ratio test and an absolute distance cap, and are returned
        sorted by ascending descriptor distance.

        Args:
            other: The chunk to match against. The matcher is selected from
                this chunk's ``feature_method``.

        Returns:
            ``(src_pts, dst_pts, good_matches, kp1, kp2)``. ``src_pts`` and
            ``dst_pts`` are float32 arrays of shape ``(N, 1, 2)`` holding the
            matched keypoint pixel coordinates in this image and in ``other``
            (raw image coordinates, not centred). All five values are ``None``
            when either image has too few features or fewer than
            ``_MIN_POINTS`` matches survive filtering.
        """
        kp1, des1 = self.compute_keypoints()
        kp2, des2 = other.compute_keypoints()

        too_few_features = len(kp1) < self._MIN_POINTS or len(kp2) < self._MIN_POINTS
        if des1 is None or des2 is None or too_few_features:
            return None, None, None, None, None

        # kNN matching + Lowe ratio test.
        matcher = self._get_matcher()
        good_matches: list[cv2.DMatch] = []
        for pair in matcher.knnMatch(des1, des2, k=2):
            if len(pair) < 2:
                # No second neighbour, so the ratio test cannot be applied.
                continue
            best, runner_up = pair
            if best.distance < self._LOWE_RATIO * runner_up.distance:
                good_matches.append(best)

        # Absolute distance cap, then best-first ordering.
        # NOTE(review): the same cap is applied to SIFT's L2 distances, which
        # are on a different scale than Hamming distances - confirm intent.
        good_matches = sorted(
            (m for m in good_matches if m.distance < self._MAX_MATCH_DISTANCE),
            key=lambda m: m.distance,
        )

        if len(good_matches) < self._MIN_POINTS:
            return None, None, None, None, None

        src_pts = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)

        return src_pts, dst_pts, good_matches, kp1, kp2

    def to_cv2_gray(self) -> np.ndarray:
        """Return the image as a CLAHE-equalised grayscale OpenCV array."""
        return self._preprocess(np.array(self.image))

    def get_shape(self) -> Tuple[int, int]:
        """Return the image size as ``(height, width)``."""
        return self.image.height, self.image.width

    def save_image(self, path: Path | str, format: str = "PNG") -> None:
        """Save the image to ``path``, creating parent directories as needed.

        Args:
            path: Destination file.
            format: Image format name; upper-cased before being passed to PIL.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        self.image.save(path, format=format.upper())

    def to_numpy(self) -> np.ndarray:
        """Return the image as a NumPy array."""
        return np.array(self.image)

    @classmethod
    def load_image(cls, path: Path | str, feature_method: FeatureMethod = DEFAULT_METHOD) -> "VisionChunk":
        """Build a chunk from an image file.

        The pixel data is read eagerly so that the underlying file handle is
        released before returning, rather than staying open for the lifetime
        of the lazily loaded image.

        Args:
            path: Image file to read.
            feature_method: Detector to use for this chunk.
        """
        path = Path(path)
        with Image.open(path) as image:
            image.load()
        return cls(image=image, feature_method=feature_method)
|