initialize weights, enhance graphics
This commit is contained in:
@@ -13,6 +13,7 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
import seaborn as sns
|
||||
from torch.utils.data import DataLoader, Dataset, Subset
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torchvision import transforms, models
|
||||
|
||||
@@ -78,15 +78,18 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"def matrix_to_homography_params(H, K):\n",
|
||||
" if hasattr(H, 'numpy'):\n",
|
||||
" H = H.numpy()\n",
|
||||
" K_inv = np.linalg.inv(K)\n",
|
||||
" E = K_inv @ H @ K\n",
|
||||
" scale = np.sqrt(np.linalg.det(E[:2, :2]))\n",
|
||||
" R = E[:2, :2] / scale\n",
|
||||
" tx, ty = E[0, 2], E[1, 2]\n",
|
||||
" rz = np.arctan2(R[1, 0], R[0, 0])\n",
|
||||
" r20, r21 = E[2, 0], E[2, 1]\n",
|
||||
" ry = np.arctan2(r20, r21)\n",
|
||||
" rx = np.arctan2(-E[1, 2], E[1, 1])\n",
|
||||
" scale = E[2, 2]\n",
|
||||
" R_normalized = E / scale\n",
|
||||
" rz = np.arctan2(R_normalized[1, 0], R_normalized[0, 0])\n",
|
||||
" ry = np.arctan2(-R_normalized[2, 0], np.sqrt(R_normalized[2, 1]**2 + R_normalized[2, 2]**2))\n",
|
||||
" rx = np.arctan2(R_normalized[2, 1], R_normalized[2, 2])\n",
|
||||
" A = R_normalized[:2, :2]\n",
|
||||
" correction = scale * np.array([R_normalized[0, 2], R_normalized[1, 2]])\n",
|
||||
" tx, ty = np.linalg.solve(A, E[:2, 2] - correction)\n",
|
||||
" return np.array([tx, ty, rx, ry, rz, scale], dtype=np.float32)\n",
|
||||
"\n"
|
||||
]
|
||||
@@ -149,6 +152,7 @@
|
||||
" self._cached_google = [None] * len(self.image_pairs)\n",
|
||||
" self._cached_yandex = [None] * len(self.image_pairs)\n",
|
||||
" self._cached_homography = [None] * len(self.image_pairs)\n",
|
||||
" self._cached_params = [None] * len(self.image_pairs)\n",
|
||||
"\n",
|
||||
" def _generate_augmented(self, idx):\n",
|
||||
" google_img = self._google_images[idx].copy()\n",
|
||||
@@ -158,14 +162,11 @@
|
||||
" params2 = generate_random_homography_params()\n",
|
||||
" H1 = homography_params_to_matrix(params1, self.K)\n",
|
||||
" H2 = homography_params_to_matrix(params2, self.K)\n",
|
||||
" H_combined = np.linalg.inv(H1) @ H2\n",
|
||||
" \n",
|
||||
" google_warped = cv2.warpPerspective(google_img, H2, (self.image_size[1], self.image_size[0]))\n",
|
||||
" yandex_warped = cv2.warpPerspective(yandex_img, H1, (self.image_size[1], self.image_size[0]))\n",
|
||||
" google_warped = cv2.warpPerspective(google_img, H2 @ H1, (self.image_size[1], self.image_size[0]))\n",
|
||||
" \n",
|
||||
" target_params = matrix_to_homography_params(H_combined, self.K)\n",
|
||||
" \n",
|
||||
" return google_warped, yandex_warped, H_combined, target_params\n",
|
||||
" return google_warped, yandex_warped, H2, params2\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.image_pairs)\n",
|
||||
@@ -179,13 +180,14 @@
|
||||
" google_img = self._cached_google[idx]\n",
|
||||
" yandex_img = self._cached_yandex[idx]\n",
|
||||
" target_matrix = self._cached_homography[idx]\n",
|
||||
" target_params = matrix_to_homography_params(target_matrix, self.K)\n",
|
||||
" target_params = self._cached_params[idx]\n",
|
||||
" elif self.augment:\n",
|
||||
" google_img, yandex_img, target_matrix, target_params = self._generate_augmented(idx)\n",
|
||||
" if self.cache_level > 0:\n",
|
||||
" self._cached_google[idx] = google_img\n",
|
||||
" self._cached_yandex[idx] = yandex_img\n",
|
||||
" self._cached_homography[idx] = target_matrix\n",
|
||||
" self._cached_params[idx] = target_params\n",
|
||||
" else:\n",
|
||||
" google_img = self._google_images[idx]\n",
|
||||
" yandex_img = self._yandex_images[idx]\n",
|
||||
@@ -238,6 +240,29 @@
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"train_loader, val_loader = create_data_loaders(config['data_dir'])\n",
|
||||
"batch = next(iter(train_loader))\n",
|
||||
"google_img = batch['google_img'][0]\n",
|
||||
"yandex_img = batch['yandex_img'][0]\n",
|
||||
"\n",
|
||||
"# google_img.permute((1, 2, 0)) * 255\n",
|
||||
"batch['homography_params'].mean(axis=0)\n",
|
||||
"\n",
|
||||
"print(batch['homography_matrix'][0])\n",
|
||||
"print(batch['homography_params'][0])\n",
|
||||
"K = get_camera_matrix(config['image_size'][0], config['image_size'][1])\n",
|
||||
"print(homography_params_to_matrix(batch['homography_params'][0], K))\n",
|
||||
"print(matrix_to_homography_params(batch['homography_matrix'][0].numpy(), K))\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -291,7 +316,7 @@
|
||||
"\n",
|
||||
" output = torch.tanh(output) # [-1; 1]\n",
|
||||
" modified = output.clone()\n",
|
||||
" modified[:, 2:5] = torch.mul(output[:, 2:5], torch.pi) # [-pi; pi]\n",
|
||||
" modified[:, 2:6] = torch.mul(output[:, 2:6], torch.pi) # [-pi; pi]\n",
|
||||
"\n",
|
||||
" return modified\n",
|
||||
"\n",
|
||||
|
||||
@@ -2,11 +2,14 @@ import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
from .dataloader import create_data_loaders
|
||||
from .model import angular_difference
|
||||
from .utils import config
|
||||
|
||||
sns.set_theme(style="whitegrid", palette="muted", font_scale=1.2)
|
||||
|
||||
IMG_DIR = os.path.join(config["output_dir"], "images")
|
||||
os.makedirs(IMG_DIR, exist_ok=True)
|
||||
|
||||
@@ -82,84 +85,132 @@ def analyze_training(trainer):
|
||||
mean_errors = [np.mean(all_errors[i]) for i in range(6)]
|
||||
std_errors = [np.std(all_errors[i]) for i in range(6)]
|
||||
|
||||
angle_errors_deg = [np.degrees(mean_errors[i]) for i in range(2, 5)]
|
||||
|
||||
all_targets_stacked = [np.array(all_targets[i]) for i in range(6)]
|
||||
target_ranges = [np.ptp(all_targets_stacked[i]) for i in range(6)]
|
||||
relative_errors = [mean_errors[i] / target_ranges[i] if target_ranges[i] > 1e-8 else 0 for i in range(6)]
|
||||
|
||||
if len(trainer.train_losses) > 0:
|
||||
epochs = range(1, len(trainer.train_losses) + 1)
|
||||
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
|
||||
|
||||
axes[0, 0].plot(epochs, trainer.train_losses, "b-", label="Train Loss")
|
||||
axes[0, 0].plot(epochs, trainer.val_losses, "r-", label="Val Loss")
|
||||
axes[0, 0].plot(epochs, trainer.train_losses, color="#2ecc71", linewidth=2, label="Train Loss")
|
||||
axes[0, 0].plot(epochs, trainer.val_losses, color="#e74c3c", linewidth=2, label="Val Loss")
|
||||
axes[0, 0].set_xlabel("Epoch")
|
||||
axes[0, 0].set_ylabel("Loss")
|
||||
axes[0, 0].set_title("Training & Validation Loss")
|
||||
axes[0, 0].legend()
|
||||
axes[0, 0].set_title("Training & Validation Loss", fontweight="bold")
|
||||
axes[0, 0].legend(framealpha=0.9)
|
||||
axes[0, 0].grid(True, alpha=0.3)
|
||||
|
||||
axes[0, 1].plot(epochs, trainer.val_losses, "r-", label="Val Loss")
|
||||
axes[0, 1].plot(epochs, trainer.val_losses, color="#e74c3c", linewidth=2, label="Val Loss")
|
||||
axes[0, 1].set_xlabel("Epoch")
|
||||
axes[0, 1].set_ylabel("Loss")
|
||||
axes[0, 1].set_title("Validation Loss")
|
||||
axes[0, 1].legend()
|
||||
axes[0, 1].set_title("Validation Loss", fontweight="bold")
|
||||
axes[0, 1].legend(framealpha=0.9)
|
||||
axes[0, 1].grid(True, alpha=0.3)
|
||||
|
||||
axes[1, 0].plot(epochs, trainer.val_mse_trans, "g-", label="Translation (tx, ty)")
|
||||
axes[1, 0].plot(epochs, trainer.val_mse_angle, "m-", label="Angle (rx, ry, rz)")
|
||||
axes[1, 0].plot(epochs, trainer.val_mse_scale, "c-", label="Scale")
|
||||
axes[1, 0].plot(epochs, trainer.val_mse_trans, color="#3498db", linewidth=2, label="Translation (tx, ty)")
|
||||
axes[1, 0].plot(epochs, trainer.val_mse_angle, color="#9b59b6", linewidth=2, label="Angle (rx, ry, rz)")
|
||||
axes[1, 0].plot(epochs, trainer.val_mse_scale, color="#e67e22", linewidth=2, label="Scale")
|
||||
axes[1, 0].set_xlabel("Epoch")
|
||||
axes[1, 0].set_ylabel("MSE")
|
||||
axes[1, 0].set_title("Validation MSE by Category")
|
||||
axes[1, 0].legend()
|
||||
axes[1, 0].set_title("Validation MSE by Category", fontweight="bold")
|
||||
axes[1, 0].legend(framealpha=0.9)
|
||||
axes[1, 0].grid(True, alpha=0.3)
|
||||
|
||||
x_pos = np.arange(6)
|
||||
axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=["c", "m", "y", "g", "b", "r"], alpha=0.8)
|
||||
colors = ["#3498db", "#e74c3c", "#9b59b6", "#2ecc71", "#f39c12", "#1abc9c"]
|
||||
bars = axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=6, color=colors, alpha=0.85, edgecolor="white", linewidth=1.5)
|
||||
axes[1, 1].set_xticks(x_pos)
|
||||
axes[1, 1].set_xticklabels(names)
|
||||
axes[1, 1].set_ylabel("Mean Absolute Error")
|
||||
axes[1, 1].set_title(f"Mean Absolute Error per Parameter ({n_samples} samples)")
|
||||
axes[1, 1].set_title(f"Mean Absolute Error per Parameter ({n_samples} samples)", fontweight="bold")
|
||||
axes[1, 1].grid(True, alpha=0.3, axis="y")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(IMG_DIR, "training_loss_plots.png"), dpi=150)
|
||||
plt.savefig(os.path.join(IMG_DIR, "training_loss_plots.png"), dpi=150, bbox_inches="tight")
|
||||
print("Saved training_loss_plots.png")
|
||||
plt.show()
|
||||
|
||||
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
|
||||
colors = ["#3498db", "#e74c3c", "#9b59b6", "#2ecc71", "#f39c12", "#1abc9c"]
|
||||
for j in range(6):
|
||||
row = j // 3
|
||||
col = j % 3
|
||||
axes[row, col].bar(range(len(all_errors[j])), all_errors[j], color="steelblue", alpha=0.7)
|
||||
axes[row, col].set_xlabel("Sample")
|
||||
axes[row, col].set_ylabel("Absolute Error")
|
||||
axes[row, col].set_title(f"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}")
|
||||
axes[row, col].bar(range(len(all_errors[j])), all_errors[j], color=colors[j], alpha=0.75)
|
||||
axes[row, col].set_xlabel("Sample", fontsize=10)
|
||||
axes[row, col].set_ylabel("Absolute Error", fontsize=10)
|
||||
axes[row, col].set_title(f"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}", fontweight="bold", fontsize=11)
|
||||
axes[row, col].grid(True, alpha=0.3, axis="y")
|
||||
plt.suptitle(f"Mean Absolute Error per Parameter ({n_samples} samples)", fontsize=14)
|
||||
plt.suptitle(f"Mean Absolute Error per Parameter ({n_samples} samples)", fontsize=14, fontweight="bold")
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(IMG_DIR, "mae_per_parameter.png"), dpi=150)
|
||||
plt.savefig(os.path.join(IMG_DIR, "mae_per_parameter.png"), dpi=150, bbox_inches="tight")
|
||||
print("Saved mae_per_parameter.png")
|
||||
plt.show()
|
||||
|
||||
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
|
||||
|
||||
x_pos = np.arange(6)
|
||||
colors = ["#3498db", "#e74c3c", "#9b59b6", "#2ecc71", "#f39c12", "#1abc9c"]
|
||||
bars = axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=6, color=colors, alpha=0.85, edgecolor="white", linewidth=1.5)
|
||||
axes[0].set_xticks(x_pos)
|
||||
axes[0].set_xticklabels(names)
|
||||
axes[0].set_ylabel("Mean Absolute Error")
|
||||
axes[0].set_title("Mean Absolute Error per Parameter (with std)", fontweight="bold")
|
||||
axes[0].grid(True, alpha=0.3, axis="y")
|
||||
|
||||
bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)
|
||||
for patch, color in zip(bp["boxes"], colors):
|
||||
patch.set_facecolor(color)
|
||||
patch.set_alpha(0.8)
|
||||
axes[1].set_ylabel("Absolute Error")
|
||||
axes[1].set_title(f"Error Distribution per Parameter ({n_samples} samples)", fontweight="bold")
|
||||
axes[1].grid(True, alpha=0.3, axis="y")
|
||||
|
||||
rel_err_pos = np.arange(6)
|
||||
bars = axes[2].bar(rel_err_pos, relative_errors, color=colors, alpha=0.85, edgecolor="white", linewidth=1.5)
|
||||
axes[2].set_xticks(rel_err_pos)
|
||||
axes[2].set_xticklabels(names)
|
||||
axes[2].set_ylabel("Relative Error (MAE / Range)")
|
||||
axes[2].set_title("Relative Error per Parameter", fontweight="bold")
|
||||
axes[2].grid(True, alpha=0.3, axis="y")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(IMG_DIR, "mae_boxplot.png"), dpi=150, bbox_inches="tight")
|
||||
print("Saved mae_boxplot.png")
|
||||
plt.show()
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
|
||||
|
||||
x_pos = np.arange(6)
|
||||
axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=["c", "m", "y", "g", "b", "r"], alpha=0.8)
|
||||
angle_names = ["rx", "ry", "rz"]
|
||||
x_pos = np.arange(3)
|
||||
colors_angle = ["#9b59b6", "#2ecc71", "#f39c12"]
|
||||
bars = axes[0].bar(x_pos, angle_errors_deg, color=colors_angle, alpha=0.85, edgecolor="white", linewidth=1.5)
|
||||
axes[0].set_xticks(x_pos)
|
||||
axes[0].set_xticklabels(names)
|
||||
axes[0].set_ylabel("Mean Absolute Error")
|
||||
axes[0].set_title("Mean Absolute Error per Parameter (with std)")
|
||||
axes[0].set_xticklabels(angle_names)
|
||||
axes[0].set_ylabel("Mean Absolute Error (degrees)")
|
||||
axes[0].set_title("Angle MAE in Degrees", fontweight="bold")
|
||||
axes[0].grid(True, alpha=0.3, axis="y")
|
||||
for i, e in enumerate(angle_errors_deg):
|
||||
axes[0].text(i, e + 0.5, f"{e:.1f}°", ha="center", va="bottom", fontsize=11, fontweight="bold")
|
||||
|
||||
bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)
|
||||
colors = ["c", "m", "y", "g", "b", "r"]
|
||||
for patch, color in zip(bp["boxes"], colors):
|
||||
patch.set_facecolor(color)
|
||||
patch.set_alpha(0.7)
|
||||
axes[1].set_ylabel("Absolute Error")
|
||||
axes[1].set_title(f"Error Distribution per Parameter ({n_samples} samples)")
|
||||
trans_scale_errs = [mean_errors[0], mean_errors[1], mean_errors[5]]
|
||||
trans_scale_names = ["tx", "ty", "scale"]
|
||||
x_pos = np.arange(3)
|
||||
colors_trans = ["#3498db", "#e74c3c", "#1abc9c"]
|
||||
bars = axes[1].bar(x_pos, trans_scale_errs, color=colors_trans, alpha=0.85, edgecolor="white", linewidth=1.5)
|
||||
axes[1].set_xticks(x_pos)
|
||||
axes[1].set_xticklabels(trans_scale_names)
|
||||
axes[1].set_ylabel("Mean Absolute Error")
|
||||
axes[1].set_title("Translation & Scale MAE", fontweight="bold")
|
||||
axes[1].grid(True, alpha=0.3, axis="y")
|
||||
for i, e in enumerate(trans_scale_errs):
|
||||
axes[1].text(i, e + 0.01, f"{e:.4f}", ha="center", va="bottom", fontsize=11, fontweight="bold")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(IMG_DIR, "mae_boxplot.png"), dpi=150)
|
||||
print("Saved mae_boxplot.png")
|
||||
plt.savefig(os.path.join(IMG_DIR, "mae_by_category.png"), dpi=150, bbox_inches="tight")
|
||||
print("Saved mae_by_category.png")
|
||||
plt.show()
|
||||
|
||||
print("\n=== Sample Predictions (20 pairs) ===")
|
||||
@@ -196,50 +247,58 @@ def analyze_training(trainer):
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||||
|
||||
axes[0, 0].imshow(google_img[0].cpu().permute(1, 2, 0))
|
||||
axes[0, 0].set_title(f"Google Image")
|
||||
axes[0, 0].set_title("Google Image", fontweight="bold", fontsize=12)
|
||||
axes[0, 0].axis("off")
|
||||
|
||||
axes[0, 1].imshow(yandex_img[0].cpu().permute(1, 2, 0))
|
||||
axes[0, 1].set_title(f"Yandex Image")
|
||||
axes[0, 1].set_title("Yandex Image", fontweight="bold", fontsize=12)
|
||||
axes[0, 1].axis("off")
|
||||
|
||||
x_pos = np.arange(6)
|
||||
width = 0.35
|
||||
axes[1, 0].bar(x_pos - width/2, targets, width, label="Target", color="steelblue", alpha=0.8)
|
||||
axes[1, 0].bar(x_pos + width/2, preds, width, label="Predicted", color="coral", alpha=0.8)
|
||||
axes[1, 0].bar(x_pos - width/2, targets, width, label="Target", color="#3498db", alpha=0.85)
|
||||
axes[1, 0].bar(x_pos + width/2, preds, width, label="Predicted", color="#e74c3c", alpha=0.85)
|
||||
axes[1, 0].set_xticks(x_pos)
|
||||
axes[1, 0].set_xticklabels(names)
|
||||
axes[1, 0].set_ylabel("Parameter Value")
|
||||
axes[1, 0].set_title("Target vs Predicted")
|
||||
axes[1, 0].legend()
|
||||
axes[1, 0].set_title("Target vs Predicted", fontweight="bold", fontsize=12)
|
||||
axes[1, 0].legend(framealpha=0.9)
|
||||
axes[1, 0].grid(True, alpha=0.3, axis="y")
|
||||
|
||||
axes[1, 1].bar(x_pos, errors, color=["c", "m", "y", "g", "b", "r"], alpha=0.8)
|
||||
colors = ["#3498db", "#e74c3c", "#9b59b6", "#2ecc71", "#f39c12", "#1abc9c"]
|
||||
bars = axes[1, 1].bar(x_pos, errors, color=colors, alpha=0.85, edgecolor="white", linewidth=1.2)
|
||||
axes[1, 1].set_xticks(x_pos)
|
||||
axes[1, 1].set_xticklabels(names)
|
||||
axes[1, 1].set_ylabel("Absolute Error")
|
||||
axes[1, 1].set_title(f"Prediction Error (Mean: {np.mean(errors):.4f})")
|
||||
axes[1, 1].set_title(f"Prediction Error (Mean: {np.mean(errors):.4f})", fontweight="bold", fontsize=12)
|
||||
axes[1, 1].grid(True, alpha=0.3, axis="y")
|
||||
for i_e, e in enumerate(errors):
|
||||
axes[1, 1].text(i_e, e + 0.01, f"{e:.3f}", ha="center", va="bottom", fontsize=8)
|
||||
axes[1, 1].text(i_e, e + 0.01, f"{e:.3f}", ha="center", va="bottom", fontsize=9)
|
||||
|
||||
plt.suptitle(f"Sample {vis_count + 1}", fontsize=14)
|
||||
plt.suptitle(f"Sample {vis_count + 1}", fontsize=14, fontweight="bold")
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(IMG_DIR, f"prediction_sample_{vis_count + 1:02d}.png"), dpi=100)
|
||||
plt.savefig(os.path.join(IMG_DIR, f"prediction_sample_{vis_count + 1:02d}.png"), dpi=100, bbox_inches="tight")
|
||||
plt.show()
|
||||
print(f"Saved prediction_sample_{vis_count + 1:02d}.png")
|
||||
|
||||
vis_count += 1
|
||||
|
||||
print(f"\nPrediction errors over {n_samples} samples:")
|
||||
print(f"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8}")
|
||||
print("-" * 52)
|
||||
print(f"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8} {'Rel Err':>10}")
|
||||
print("-" * 62)
|
||||
for i in range(6):
|
||||
mean_err = np.mean(all_errors[i])
|
||||
std_err = np.std(all_errors[i])
|
||||
min_err = np.min(all_errors[i])
|
||||
max_err = np.max(all_errors[i])
|
||||
print(f"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f}")
|
||||
rel_err = relative_errors[i]
|
||||
print(f"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f} {rel_err:>10.4f}")
|
||||
|
||||
print(f"\nAngle errors in degrees:")
|
||||
print(f"{'Param':<8} {'MAE (deg)':>12} {'MAE (rad)':>12}")
|
||||
print("-" * 35)
|
||||
for i, name in enumerate(["rx", "ry", "rz"]):
|
||||
print(f"{name:<8} {angle_errors_deg[i]:>12.2f} {mean_errors[i+2]:>12.4f}")
|
||||
|
||||
return {
|
||||
"best_val_loss": trainer.best_val_loss,
|
||||
@@ -248,4 +307,8 @@ def analyze_training(trainer):
|
||||
"val_mse_trans": trainer.val_mse_trans,
|
||||
"val_mse_angle": trainer.val_mse_angle,
|
||||
"val_mse_scale": trainer.val_mse_scale,
|
||||
"mean_errors": mean_errors,
|
||||
"std_errors": std_errors,
|
||||
"angle_errors_deg": angle_errors_deg,
|
||||
"relative_errors": relative_errors,
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset, Subset
|
||||
from torchvision import transforms
|
||||
|
||||
from .utils import config, get_camera_matrix, generate_random_homography_params, homography_params_to_matrix, matrix_to_homography_params
|
||||
from .utils import config, get_camera_matrix, generate_random_homography_params, homography_params_to_matrix
|
||||
|
||||
|
||||
class YaGoDataset(Dataset):
|
||||
@@ -55,6 +55,7 @@ class YaGoDataset(Dataset):
|
||||
self._cached_google = [None] * len(self.image_pairs)
|
||||
self._cached_yandex = [None] * len(self.image_pairs)
|
||||
self._cached_homography = [None] * len(self.image_pairs)
|
||||
self._cached_params = [None] * len(self.image_pairs)
|
||||
|
||||
def _generate_augmented(self, idx):
|
||||
google_img = self._google_images[idx].copy()
|
||||
@@ -64,14 +65,11 @@ class YaGoDataset(Dataset):
|
||||
params2 = generate_random_homography_params()
|
||||
H1 = homography_params_to_matrix(params1, self.K)
|
||||
H2 = homography_params_to_matrix(params2, self.K)
|
||||
H_combined = np.linalg.inv(H1) @ H2
|
||||
|
||||
google_warped = cv2.warpPerspective(google_img, H2, (self.image_size[1], self.image_size[0]))
|
||||
yandex_warped = cv2.warpPerspective(yandex_img, H1, (self.image_size[1], self.image_size[0]))
|
||||
google_warped = cv2.warpPerspective(google_img, H2 @ H1, (self.image_size[1], self.image_size[0]))
|
||||
|
||||
target_params = matrix_to_homography_params(H_combined, self.K)
|
||||
|
||||
return google_warped, yandex_warped, H_combined, target_params
|
||||
return google_warped, yandex_warped, H2, params2
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_pairs)
|
||||
@@ -85,13 +83,14 @@ class YaGoDataset(Dataset):
|
||||
google_img = self._cached_google[idx]
|
||||
yandex_img = self._cached_yandex[idx]
|
||||
target_matrix = self._cached_homography[idx]
|
||||
target_params = matrix_to_homography_params(target_matrix, self.K)
|
||||
target_params = self._cached_params[idx]
|
||||
elif self.augment:
|
||||
google_img, yandex_img, target_matrix, target_params = self._generate_augmented(idx)
|
||||
if self.cache_level > 0:
|
||||
self._cached_google[idx] = google_img
|
||||
self._cached_yandex[idx] = yandex_img
|
||||
self._cached_homography[idx] = target_matrix
|
||||
self._cached_params[idx] = target_params
|
||||
else:
|
||||
google_img = self._google_images[idx]
|
||||
yandex_img = self._yandex_images[idx]
|
||||
|
||||
@@ -29,11 +29,19 @@ class HomographyCNN6(nn.Module):
|
||||
nn.Dropout(dropout_rate),
|
||||
nn.Linear(256, 6),
|
||||
)
|
||||
self._init_weights()
|
||||
|
||||
def _normalize_sin_cos(self, _sin, _cos):
|
||||
_len = torch.sqrt(_sin ** 2 + _cos ** 2)
|
||||
return _sin / _len, _cos / _len
|
||||
|
||||
def _init_weights(self):
|
||||
for module in self.head.modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
def forward(self, img1, img2):
|
||||
f1 = self.backbone(img1)
|
||||
f2 = self.backbone(img2)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -43,13 +44,16 @@ def homography_params_to_matrix(params, K):
|
||||
|
||||
|
||||
def matrix_to_homography_params(H, K):
|
||||
if hasattr(H, 'numpy'):
|
||||
H = H.numpy()
|
||||
K_inv = np.linalg.inv(K)
|
||||
E = K_inv @ H @ K
|
||||
scale = np.sqrt(np.linalg.det(E[:2, :2]))
|
||||
R = E[:2, :2] / scale
|
||||
tx, ty = E[0, 2], E[1, 2]
|
||||
rz = np.arctan2(R[1, 0], R[0, 0])
|
||||
r20, r21 = E[2, 0], E[2, 1]
|
||||
ry = np.arctan2(r20, r21)
|
||||
rx = np.arctan2(-E[1, 2], E[1, 1])
|
||||
scale = E[2, 2]
|
||||
R_normalized = E / scale
|
||||
rz = np.arctan2(R_normalized[1, 0], R_normalized[0, 0])
|
||||
ry = np.arctan2(-R_normalized[2, 0], np.sqrt(R_normalized[2, 1]**2 + R_normalized[2, 2]**2))
|
||||
rx = np.arctan2(R_normalized[2, 1], R_normalized[2, 2])
|
||||
A = R_normalized[:2, :2]
|
||||
correction = scale * np.array([R_normalized[0, 2], R_normalized[1, 2]])
|
||||
tx, ty = np.linalg.solve(A, E[:2, 2] - correction)
|
||||
return np.array([tx, ty, rx, ry, rz, scale], dtype=np.float32)
|
||||
|
||||
Reference in New Issue
Block a user