ref: create vision_chunk
This commit is contained in:
161
vision_chunk.py
Normal file
161
vision_chunk.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Tuple
|
||||
from PIL import Image
|
||||
|
||||
# Names accepted for ``VisionChunk.feature_method``.
# NOTE(review): only "orb" is implemented by ``VisionChunk._get_detector`` in
# this file; "sift" and "surf" pass type checking but raise ValueError there.
FeatureMethod = Literal["orb", "sift", "surf"]
|
||||
|
||||
@dataclass
class VisionChunk:
    """An image together with lazily computed local features.

    Keypoints and descriptors are detected on a CLAHE-enhanced grayscale
    copy of ``image`` and cached on the instance, so that two chunks can be
    matched against each other repeatedly without re-running detection.

    Attributes:
        image: Source PIL image.
        feature_method: Name of the feature detector. Only ``"orb"`` is
            implemented; any other value makes feature extraction raise
            ``ValueError``.
        keypoints: Keypoints cached by the last ``compute_keypoints`` call.
        descriptors: Descriptors cached by the last ``compute_keypoints``
            call; may be ``None`` when the detector produced none.
    """

    # PIL modes whose raw arrays are not plain intensity / RGB(A) samples
    # (palette indices, booleans, other colour spaces). They are converted to
    # RGB before grayscale conversion. Deliberately un-annotated so that the
    # dataclass machinery does not treat it as a field.
    _CONVERT_TO_RGB_MODES = ("1", "P", "PA", "LA", "CMYK", "YCbCr")

    image: Image.Image
    feature_method: FeatureMethod = "orb"

    keypoints: Optional[list] = field(default=None, init=False)
    descriptors: Optional[np.ndarray] = field(default=None, init=False)
    _detector: Optional[cv2.Feature2D] = field(default=None, init=False, repr=False)
    _matcher: Optional[cv2.DescriptorMatcher] = field(default=None, init=False, repr=False)

    def _get_detector(self) -> cv2.Feature2D:
        """Return the feature detector, creating and caching it on first use.

        Raises:
            ValueError: If ``feature_method`` is anything other than ``"orb"``.
        """
        if self._detector is not None:
            return self._detector

        if self.feature_method != "orb":
            raise ValueError(f"Unsupported feature method: {self.feature_method}")

        self._detector = cv2.ORB_create(
            nfeatures=1000,
            scaleFactor=1.2,
            nlevels=8,
            edgeThreshold=31,
            firstLevel=0,
            WTA_K=2,
            patchSize=31,
            fastThreshold=20,
        )
        return self._detector

    def _get_matcher(self) -> cv2.DescriptorMatcher:
        """Return the brute-force Hamming matcher, creating it on first use.

        Cross-checking is disabled because the caller needs the two nearest
        neighbours per descriptor for the ratio test.
        """
        if self._matcher is None:
            self._matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
        return self._matcher

    def _preprocess(self, img_np: np.ndarray) -> np.ndarray:
        """Return a CLAHE contrast-enhanced grayscale version of ``img_np``.

        Args:
            img_np: Either a 2-D grayscale array or a 3-D array whose channel
                axis is in RGB order (as produced by ``np.array`` on a PIL
                image).

        Returns:
            The 2-D grayscale array after CLAHE.
        """
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY) if img_np.ndim == 3 else img_np
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        return clahe.apply(gray)

    def compute_keypoints(self, force: bool = False) -> Tuple[list, Optional[np.ndarray]]:
        """Detect keypoints and descriptors, reusing the cached result.

        Args:
            force: Recompute even when a cached result is available.

        Returns:
            ``(keypoints, descriptors)`` as returned by the detector's
            ``detectAndCompute``.

        Raises:
            ValueError: If ``feature_method`` is not supported.
        """
        has_cache = self.keypoints is not None and self.descriptors is not None
        if has_cache and not force:
            return self.keypoints, self.descriptors

        detector = self._get_detector()

        # Use the same grayscale image that ``to_cv2_gray`` exposes. The PIL
        # array is RGB and ``_preprocess`` converts with COLOR_RGB2GRAY, so no
        # RGB->BGR swap may happen in between (it would exchange the red and
        # blue luma weights).
        kps, desc = detector.detectAndCompute(self.to_cv2_gray(), None)

        self.keypoints = kps
        self.descriptors = desc
        return kps, desc

    @staticmethod
    def _centered_points(keypoints, indices, shape: Tuple[int, int]) -> np.ndarray:
        """Return the selected keypoints as centred ``(N, 1, 2)`` float32 points.

        Coordinates are relative to the integer image centre, with the Y axis
        pointing up (image rows grow downwards, hence the sign flip).
        """
        height, width = shape
        cx, cy = width // 2, height // 2
        points = [
            [keypoints[i].pt[0] - cx, cy - keypoints[i].pt[1]]
            for i in indices
        ]
        return np.float32(points).reshape(-1, 1, 2)

    def detect_and_match_keypoints(
        self,
        other: "VisionChunk"
    ) -> Tuple[
        Optional[np.ndarray],
        Optional[np.ndarray],
        Optional[list],
        Optional[list],
        Optional[list]
    ]:
        """Match this chunk's features against ``other``.

        Matching uses 2-nearest-neighbour search, Lowe's ratio test (0.75)
        and an absolute descriptor-distance cut-off (64).

        Args:
            other: The chunk to match against.

        Returns:
            ``(src_pts, dst_pts, good_matches, kp1, kp2)`` where ``src_pts``
            and ``dst_pts`` are ``(N, 1, 2)`` float32 arrays of matched
            coordinates relative to each image centre with Y pointing up, and
            ``good_matches`` is sorted by ascending distance. All five values
            are ``None`` when either image has no descriptors, has fewer than
            4 keypoints, or fewer than 4 matches survive filtering.

        Raises:
            ValueError: If either chunk's ``feature_method`` is not supported.
        """
        no_result = (None, None, None, None, None)

        kp1, des1 = self.compute_keypoints()
        kp2, des2 = other.compute_keypoints()

        if des1 is None or des2 is None or len(kp1) < 4 or len(kp2) < 4:
            return no_result

        # kNN matching followed by Lowe's ratio test and a soft absolute
        # distance threshold.
        knn_pairs = self._get_matcher().knnMatch(des1, des2, k=2)
        good_matches = []
        for pair in knn_pairs:
            # Fewer than two neighbours: the ratio test cannot be applied.
            if len(pair) < 2:
                continue
            best, second = pair
            if best.distance < 0.75 * second.distance and best.distance < 64:
                good_matches.append(best)
        good_matches.sort(key=lambda match: match.distance)

        if len(good_matches) < 4:
            return no_result

        # Only the image sizes are needed here, so read them from the PIL
        # images instead of re-running grayscale conversion and CLAHE.
        src_pts = self._centered_points(
            kp1, [match.queryIdx for match in good_matches], self.get_shape()
        )
        dst_pts = self._centered_points(
            kp2, [match.trainIdx for match in good_matches], other.get_shape()
        )

        return src_pts, dst_pts, good_matches, kp1, kp2

    def to_cv2_gray(self) -> np.ndarray:
        """Return the image as a CLAHE-enhanced grayscale array.

        Images in palette, bilevel and other non-RGB modes are converted to
        RGB first so that the array holds real colour samples rather than
        palette indices or booleans.
        """
        image = self.image
        if image.mode in self._CONVERT_TO_RGB_MODES:
            image = image.convert("RGB")
        return self._preprocess(np.array(image))

    def get_shape(self) -> Tuple[int, int]:
        """Return the image size as ``(height, width)``."""
        return self.image.height, self.image.width

    def save_image(self, path: Path | str, format: str = "PNG") -> None:
        """Save the image to ``path``, creating parent directories as needed.

        Args:
            path: Destination file path.
            format: Image format name passed to PIL (upper-cased first).
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        self.image.save(path, format=format.upper())

    def to_numpy(self) -> np.ndarray:
        """Return the image as a NumPy array in its native PIL mode."""
        return np.array(self.image)

    @classmethod
    def load_image(cls, path: Path | str, feature_method: FeatureMethod = "orb") -> "VisionChunk":
        """Create a chunk from an image file.

        Args:
            path: Path of the image file to read.
            feature_method: Feature detector name for the new chunk.

        Returns:
            A new ``VisionChunk`` holding the fully loaded image.
        """
        # ``Image.open`` is lazy and keeps the file open; load the pixel data
        # inside the context manager so the handle is released on exit while
        # the image itself stays usable.
        with Image.open(Path(path)) as image:
            image.load()
        return cls(image=image, feature_method=feature_method)
|
||||
Reference in New Issue
Block a user