initialize weights, enhance graphics

This commit is contained in:
2026-04-06 22:11:41 +03:00
parent 3372ee4055
commit 2ec0763e6d
6 changed files with 177 additions and 77 deletions

View File

@@ -78,15 +78,18 @@
"\n",
"\n",
"def matrix_to_homography_params(H, K):\n",
" if hasattr(H, 'numpy'):\n",
" H = H.numpy()\n",
" K_inv = np.linalg.inv(K)\n",
" E = K_inv @ H @ K\n",
" scale = np.sqrt(np.linalg.det(E[:2, :2]))\n",
" R = E[:2, :2] / scale\n",
" tx, ty = E[0, 2], E[1, 2]\n",
" rz = np.arctan2(R[1, 0], R[0, 0])\n",
" r20, r21 = E[2, 0], E[2, 1]\n",
" ry = np.arctan2(r20, r21)\n",
" rx = np.arctan2(-E[1, 2], E[1, 1])\n",
" scale = E[2, 2]\n",
" R_normalized = E / scale\n",
" rz = np.arctan2(R_normalized[1, 0], R_normalized[0, 0])\n",
" ry = np.arctan2(-R_normalized[2, 0], np.sqrt(R_normalized[2, 1]**2 + R_normalized[2, 2]**2))\n",
" rx = np.arctan2(R_normalized[2, 1], R_normalized[2, 2])\n",
" A = R_normalized[:2, :2]\n",
" correction = scale * np.array([R_normalized[0, 2], R_normalized[1, 2]])\n",
" tx, ty = np.linalg.solve(A, E[:2, 2] - correction)\n",
" return np.array([tx, ty, rx, ry, rz, scale], dtype=np.float32)\n",
"\n"
]
@@ -149,6 +152,7 @@
" self._cached_google = [None] * len(self.image_pairs)\n",
" self._cached_yandex = [None] * len(self.image_pairs)\n",
" self._cached_homography = [None] * len(self.image_pairs)\n",
" self._cached_params = [None] * len(self.image_pairs)\n",
"\n",
" def _generate_augmented(self, idx):\n",
" google_img = self._google_images[idx].copy()\n",
@@ -158,14 +162,11 @@
" params2 = generate_random_homography_params()\n",
" H1 = homography_params_to_matrix(params1, self.K)\n",
" H2 = homography_params_to_matrix(params2, self.K)\n",
" H_combined = np.linalg.inv(H1) @ H2\n",
" \n",
" google_warped = cv2.warpPerspective(google_img, H2, (self.image_size[1], self.image_size[0]))\n",
" yandex_warped = cv2.warpPerspective(yandex_img, H1, (self.image_size[1], self.image_size[0]))\n",
" google_warped = cv2.warpPerspective(google_img, H2 @ H1, (self.image_size[1], self.image_size[0]))\n",
" \n",
" target_params = matrix_to_homography_params(H_combined, self.K)\n",
" \n",
" return google_warped, yandex_warped, H_combined, target_params\n",
" return google_warped, yandex_warped, H2, params2\n",
"\n",
" def __len__(self):\n",
" return len(self.image_pairs)\n",
@@ -179,13 +180,14 @@
" google_img = self._cached_google[idx]\n",
" yandex_img = self._cached_yandex[idx]\n",
" target_matrix = self._cached_homography[idx]\n",
" target_params = matrix_to_homography_params(target_matrix, self.K)\n",
" target_params = self._cached_params[idx]\n",
" elif self.augment:\n",
" google_img, yandex_img, target_matrix, target_params = self._generate_augmented(idx)\n",
" if self.cache_level > 0:\n",
" self._cached_google[idx] = google_img\n",
" self._cached_yandex[idx] = yandex_img\n",
" self._cached_homography[idx] = target_matrix\n",
" self._cached_params[idx] = target_params\n",
" else:\n",
" google_img = self._google_images[idx]\n",
" yandex_img = self._yandex_images[idx]\n",
@@ -238,6 +240,29 @@
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"train_loader, val_loader = create_data_loaders(config['data_dir'])\n",
"batch = next(iter(train_loader))\n",
"google_img = batch['google_img'][0]\n",
"yandex_img = batch['yandex_img'][0]\n",
"\n",
"# google_img.permute((1, 2, 0)) * 255\n",
"batch['homography_params'].mean(axis=0)\n",
"\n",
"print(batch['homography_matrix'][0])\n",
"print(batch['homography_params'][0])\n",
"K = get_camera_matrix(config['image_size'][0], config['image_size'][1])\n",
"print(homography_params_to_matrix(batch['homography_params'][0], K))\n",
"print(matrix_to_homography_params(batch['homography_matrix'][0].numpy(), K))\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -291,7 +316,7 @@
"\n",
" output = torch.tanh(output) # [-1; 1]\n",
" modified = output.clone()\n",
" modified[:, 2:5] = torch.mul(output[:, 2:5], torch.pi) # [-pi; pi]\n",
" modified[:, 2:6] = torch.mul(output[:, 2:6], torch.pi) # [-pi; pi]\n",
"\n",
" return modified\n",
"\n",