Files
autopilot/vision_chunk.py

201 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import cv2
import json
import numpy as np
from PIL import Image
from dataclasses import dataclass, field
from pathlib import Path
from position import Position
from typing import Literal, Optional, Tuple
# Names of the feature detectors that VisionChunk knows how to build.
FeatureMethod = Literal[
    "orb",
    "sift",
    "akaze",
    "brisk",
]

# Detector used when the caller does not pick one explicitly.
DEFAULT_METHOD: FeatureMethod = "orb"
@dataclass
class VisionChunk:
    """Image plus lazily computed local features used for matching.

    The detector, the matcher, and the computed keypoints/descriptors are
    cached on the instance after first use.

    Attributes:
        image: The PIL image this chunk wraps. The grayscale conversion in
            this class assumes RGB channel order for multi-channel images.
        feature_method: Which detector to build ("orb", "sift", "akaze" or
            "brisk").
        pos: Optional position; never set or read by the methods below.
        keypoints: Cached keypoints, filled by ``compute_keypoints``.
        descriptors: Cached descriptors, filled by ``compute_keypoints``.
    """

    # Upper bound on how many of the strongest keypoints are kept.
    _MAX_KEYPOINTS = 2500
    # Lowe ratio-test factor applied to the two nearest neighbours.
    _LOWE_RATIO = 0.75
    # Absolute upper bound on the distance of an accepted match.
    _MAX_MATCH_DISTANCE = 64

    image: Image.Image
    feature_method: FeatureMethod = DEFAULT_METHOD
    pos: Optional[Position] = field(default=None, init=False)
    keypoints: Optional[list] = field(default=None, init=False)
    descriptors: Optional[np.ndarray] = field(default=None, init=False)
    _detector: Optional[cv2.Feature2D] = field(default=None, init=False, repr=False)
    _matcher: Optional[cv2.DescriptorMatcher] = field(default=None, init=False, repr=False)

    def _get_detector(self) -> cv2.Feature2D:
        """Return the feature detector for ``feature_method``, creating it once.

        Raises:
            ValueError: If ``feature_method`` is not a supported name.
        """
        if self._detector is not None:
            return self._detector

        if self.feature_method == "orb":
            self._detector = cv2.ORB_create(
                nfeatures=10000,
                scaleFactor=1.2,
                nlevels=32,
                edgeThreshold=31,
                firstLevel=0,
                WTA_K=2,
                patchSize=31,
                fastThreshold=20,
            )
        elif self.feature_method == "sift":
            self._detector = cv2.SIFT_create(
                nfeatures=1500,
                nOctaveLayers=2,
                contrastThreshold=0.01,
                edgeThreshold=15,
                sigma=3.3,
            )
        elif self.feature_method == "akaze":
            self._detector = cv2.AKAZE_create(
                descriptor_type=cv2.AKAZE_DESCRIPTOR_MLDB,
                descriptor_size=0,
                descriptor_channels=3,
                threshold=0.001,
                nOctaves=4,
                diffusivity=cv2.KAZE_DIFF_PM_G2,
            )
        elif self.feature_method == "brisk":
            self._detector = cv2.BRISK_create(
                thresh=70,
                octaves=7,
                patternScale=1.0,
            )
        else:
            raise ValueError(f"Unsupported feature method: {self.feature_method}")
        return self._detector

    def _get_matcher(self) -> cv2.DescriptorMatcher:
        """Return a brute-force matcher suited to the descriptor type, creating it once.

        SIFT descriptors are compared with the L2 norm; every other method
        uses the Hamming norm.
        """
        if self._matcher is None:
            if self.feature_method == 'sift':
                self._matcher = cv2.BFMatcher(cv2.NORM_L2, crossCheck=False)
            else:
                self._matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
        return self._matcher

    def _preprocess(self, img_np: np.ndarray) -> np.ndarray:
        """Convert to grayscale if needed and boost contrast with CLAHE.

        Args:
            img_np: A 2-D grayscale array, or a 3-D array in RGB channel order.

        Returns:
            The CLAHE-equalised single-channel image.
        """
        if len(img_np.shape) == 3:
            gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        else:
            gray = img_np
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        return clahe.apply(gray)

    def compute_keypoints(
        self, force: bool = False
    ) -> Tuple[list[cv2.KeyPoint], Optional[np.ndarray]]:
        """Detect features on the image and cache the strongest ones.

        Args:
            force: Recompute even when cached keypoints and descriptors exist.

        Returns:
            ``(keypoints, descriptors)`` holding at most ``_MAX_KEYPOINTS``
            entries ordered by descending response. When the detector finds
            nothing, returns ``([], None)``.
        """
        if self.keypoints is not None and self.descriptors is not None and not force:
            return self.keypoints, self.descriptors

        detector = self._get_detector()

        # np.array() on a PIL image is already in RGB order, which is what
        # _preprocess() expects; converting to BGR first would swap the red
        # and blue weights in the grayscale conversion.
        preprocessed = self._preprocess(np.array(self.image))
        keypoints, descriptors = detector.detectAndCompute(preprocessed, None)

        # A featureless image yields no descriptors; report that instead of
        # failing while indexing None below.
        if descriptors is None or len(keypoints) == 0:
            self.keypoints = []
            self.descriptors = None
            return self.keypoints, self.descriptors

        # Keep only the strongest keypoints, strongest first.
        responses = np.array([kp.response for kp in keypoints])
        top_indices = np.argsort(responses)[-self._MAX_KEYPOINTS:][::-1]

        self.keypoints = [keypoints[i] for i in top_indices]
        self.descriptors = descriptors[top_indices]
        return self.keypoints, self.descriptors

    def detect_and_match_keypoints(
        self,
        other: "VisionChunk"
    ) -> Tuple[
        Optional[np.ndarray],
        Optional[np.ndarray],
        Optional[list],
        Optional[list],
        Optional[list]
    ]:
        """Match this chunk's features against another chunk's.

        Args:
            other: The chunk to match against (its descriptors are the
                "train" set, this chunk's are the "query" set).

        Returns:
            ``(src_pts, dst_pts, good_matches, kp1, kp2)`` where ``src_pts``
            and ``dst_pts`` are float32 arrays of shape ``(N, 1, 2)`` with
            the matched keypoint pixel coordinates in this and the other
            image, ``good_matches`` is sorted by ascending distance, and
            ``kp1``/``kp2`` are the keypoints of the two chunks. Returns five
            ``None`` values when either side has no descriptors, fewer than
            four keypoints, or fewer than four matches survive filtering.
        """
        kp1, des1 = self.compute_keypoints()
        kp2, des2 = other.compute_keypoints()
        if des1 is None or des2 is None or len(kp1) < 4 or len(kp2) < 4:
            return None, None, None, None, None

        # kNN matching followed by Lowe's ratio test.
        matcher = self._get_matcher()
        matches_knn = matcher.knnMatch(des1, des2, k=2)
        ratio_passed = [
            pair[0]
            for pair in matches_knn
            if len(pair) >= 2 and pair[0].distance < self._LOWE_RATIO * pair[1].distance
        ]

        # Absolute distance cut-off, then order best match first.
        # NOTE(review): the cut-off looks tuned for Hamming distances; with
        # "sift" (L2 norm) it is probably far too strict - confirm.
        good_matches: list[cv2.DMatch] = sorted(
            (m for m in ratio_passed if m.distance < self._MAX_MATCH_DISTANCE),
            key=lambda m: m.distance,
        )
        if len(good_matches) < 4:
            return None, None, None, None, None

        src_pts = np.float32(
            [kp1[m.queryIdx].pt for m in good_matches]
        ).reshape(-1, 1, 2)
        dst_pts = np.float32(
            [kp2[m.trainIdx].pt for m in good_matches]
        ).reshape(-1, 1, 2)
        return src_pts, dst_pts, good_matches, kp1, kp2

    def to_cv2_gray(self) -> np.ndarray:
        """Return the image as a CLAHE-preprocessed grayscale array."""
        return self._preprocess(np.array(self.image))

    def get_shape(self) -> Tuple[int, int]:
        """Return the image size as ``(height, width)``."""
        return self.image.height, self.image.width

    def save_image(self, path: Path | str, format: str = "PNG") -> None:
        """Write the image to ``path``, creating parent directories as needed.

        Args:
            path: Destination file path.
            format: Image format name; upper-cased before being passed to PIL.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        self.image.save(path, format=format.upper())

    def to_numpy(self) -> np.ndarray:
        """Return the image as a NumPy array."""
        return np.array(self.image)

    @classmethod
    def load_image(cls, path: Path | str, feature_method: FeatureMethod = DEFAULT_METHOD) -> "VisionChunk":
        """Build a chunk from an image file.

        Args:
            path: Image file to open.
            feature_method: Detector name passed on to the new chunk.

        Returns:
            A new ``VisionChunk`` wrapping the opened image.
        """
        path = Path(path)
        image = Image.open(path)
        # Image.open() is lazy; read the pixel data now so the underlying
        # file handle is not left open for the lifetime of the chunk.
        image.load()
        return cls(image=image, feature_method=feature_method)