improve model

This commit is contained in:
2026-04-06 17:36:41 +03:00
parent b4b8f78970
commit 3372ee4055
4 changed files with 90 additions and 78 deletions

View File

@@ -43,6 +43,9 @@ Google/Yandex image pair loader with homography augmentation.
- Batch dict with `google_img`, `yandex_img`, `homography_params`"""
# code: ./src/dataloader.py
# code: ./src/test_dataloader.py
# markdown
"""## Model

View File

@@ -266,7 +266,7 @@
" self.backbone = backbone\n",
"\n",
" self.head = nn.Sequential(\n",
" nn.Linear(self.feature_dim * 4, 512),\n",
" nn.Linear(self.feature_dim * 4, 1024),\n",
" nn.ReLU(inplace=True),\n",
" nn.Dropout(dropout_rate),\n",
" nn.Linear(1024, 512),\n",
@@ -275,7 +275,7 @@
" nn.Linear(512, 256),\n",
" nn.ReLU(inplace=True),\n",
" nn.Dropout(dropout_rate),\n",
" nn.Linear(512, 9),\n",
" nn.Linear(256, 6),\n",
" )\n",
"\n",
" def _normalize_sin_cos(self, _sin, _cos):\n",
@@ -287,21 +287,21 @@
" f2 = self.backbone(img2)\n",
" combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)\n",
"\n",
" combined[:, (0, 1)] = torch.tanh(combined[:, (0, 1)]) * 10 # [-10; 10]\n",
" combined[:, (2, 3)] = self._normalize_sin_cos(torch.tanh(combined[:, 2]), torch.tanh(combined[:, 3]))\n",
" combined[:, (4, 5)] = self._normalize_sin_cos(torch.tanh(combined[:, 4]), torch.tanh(combined[:, 5]))\n",
" combined[:, (6, 7)] = self._normalize_sin_cos(torch.tanh(combined[:, 6]), torch.tanh(combined[:, 7]))\n",
" output = self.head(combined)\n",
"\n",
" return self.head(combined)\n",
" output = torch.tanh(output) # [-1; 1]\n",
" modified = output.clone()\n",
" modified[:, 2:5] = torch.mul(output[:, 2:5], torch.pi) # [-pi; pi]\n",
"\n",
" return modified\n",
"\n",
" def decode_output(self, output):\n",
" tx = output[:, 0]\n",
" ty = output[:, 1]\n",
" scale = output[:, 8]\n",
"\n",
" angle1 = torch.atan2(output[:, 2], output[:, 3])\n",
" angle2 = torch.atan2(output[:, 4], output[:, 5])\n",
" angle3 = torch.atan2(output[:, 6], output[:, 7])\n",
" scale = output[:, 5]\n",
" angle1 = output[:, 2]\n",
" angle2 = output[:, 3]\n",
" angle3 = output[:, 4]\n",
"\n",
" return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)\n",
"\n",
@@ -325,33 +325,27 @@
" self.trans_loss_weight = trans_loss_weight\n",
" self.scale_loss_weight = scale_loss_weight\n",
"\n",
" @staticmethod\n",
" def dot_angles(src, dest):\n",
" sin_src = torch.sin(src)\n",
" cos_src = torch.cos(src)\n",
" sin_dest = torch.sin(dest)\n",
" cos_dest = torch.cos(dest)\n",
" return sin_src * sin_dest + cos_src * cos_dest\n",
"\n",
" def forward(self, pred, target):\n",
" tx_loss = self.criterion(pred[:, 0], target[:, 0])\n",
" ty_loss = self.criterion(pred[:, 1], target[:, 1])\n",
"\n",
" sin_rx_pred = pred[:, 2]\n",
" cos_rx_pred = pred[:, 3]\n",
" sin_ry_pred = pred[:, 4]\n",
" cos_ry_pred = pred[:, 5]\n",
" sin_rz_pred = pred[:, 6]\n",
" cos_rz_pred = pred[:, 7]\n",
" dot_rx = HomographyLoss6.dot_angles(pred[:, 2], target[:, 2])\n",
" dot_ry = HomographyLoss6.dot_angles(pred[:, 3], target[:, 3])\n",
" dot_rz = HomographyLoss6.dot_angles(pred[:, 4], target[:, 4])\n",
"\n",
" sin_rx_target = torch.sin(target[:, 2])\n",
" cos_rx_target = torch.cos(target[:, 2])\n",
" sin_ry_target = torch.sin(target[:, 3])\n",
" cos_ry_target = torch.cos(target[:, 3])\n",
" sin_rz_target = torch.sin(target[:, 4])\n",
" cos_rz_target = torch.cos(target[:, 4])\n",
" rx_loss = self.criterion(dot_rx, torch.ones_like(dot_rx))\n",
" ry_loss = self.criterion(dot_ry, torch.ones_like(dot_ry))\n",
" rz_loss = self.criterion(dot_rz, torch.ones_like(dot_rz))\n",
"\n",
" dot_rx = sin_rx_pred * sin_rx_target + cos_rx_pred * cos_rx_target\n",
" dot_ry = sin_ry_pred * sin_ry_target + cos_ry_pred * cos_ry_target\n",
" dot_rz = sin_rz_pred * sin_rz_target + cos_rz_pred * cos_rz_target\n",
"\n",
" rx_loss = (1 - dot_rx).mean()\n",
" ry_loss = (1 - dot_ry).mean()\n",
" rz_loss = (1 - dot_rz).mean()\n",
"\n",
" scale_loss = self.criterion(pred[:, 8], target[:, 5])\n",
" scale_loss = self.criterion(pred[:, 5], target[:, 5])\n",
"\n",
" total_loss = (\n",
" self.trans_loss_weight * (tx_loss + ty_loss) +\n",
@@ -361,14 +355,17 @@
"\n",
" return total_loss\n",
"\n",
" def compute_mse_components(self, pred, target):\n",
" decoded = self.decode_output(pred)\n",
" def compute_mse_components(self, decoded, target):\n",
" tx_mse = self.criterion(decoded[:, 0], target[:, 0]).item()\n",
" ty_mse = self.criterion(decoded[:, 1], target[:, 1]).item()\n",
" \n",
" rx_mse = angular_difference(decoded[:, 2], target[:, 2]).item()\n",
" ry_mse = angular_difference(decoded[:, 3], target[:, 3]).item()\n",
" rz_mse = angular_difference(decoded[:, 4], target[:, 4]).item()\n",
" dot_rx = HomographyLoss6.dot_angles(decoded[:, 2], target[:, 2])\n",
" dot_ry = HomographyLoss6.dot_angles(decoded[:, 3], target[:, 3])\n",
" dot_rz = HomographyLoss6.dot_angles(decoded[:, 4], target[:, 4])\n",
"\n",
" rx_mse = self.criterion(dot_rx, torch.ones_like(dot_rx)).item()\n",
" ry_mse = self.criterion(dot_ry, torch.ones_like(dot_ry)).item()\n",
" rz_mse = self.criterion(dot_rz, torch.ones_like(dot_rz)).item()\n",
"\n",
" scale_mse = self.criterion(decoded[:, 5], target[:, 5]).item()\n",
"\n",

View File

@@ -18,7 +18,7 @@ class HomographyCNN6(nn.Module):
self.backbone = backbone
self.head = nn.Sequential(
nn.Linear(self.feature_dim * 4, 512),
nn.Linear(self.feature_dim * 4, 1024),
nn.ReLU(inplace=True),
nn.Dropout(dropout_rate),
nn.Linear(1024, 512),
@@ -27,7 +27,7 @@ class HomographyCNN6(nn.Module):
nn.Linear(512, 256),
nn.ReLU(inplace=True),
nn.Dropout(dropout_rate),
nn.Linear(512, 9),
nn.Linear(256, 6),
)
def _normalize_sin_cos(self, _sin, _cos):
@@ -39,21 +39,21 @@ class HomographyCNN6(nn.Module):
f2 = self.backbone(img2)
combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)
combined[:, (0, 1)] = torch.tanh(combined[:, (0, 1)]) * 10 # [-10; 10]
combined[:, (2, 3)] = self._normalize_sin_cos(torch.tanh(combined[:, 2]), torch.tanh(combined[:, 3]))
combined[:, (4, 5)] = self._normalize_sin_cos(torch.tanh(combined[:, 4]), torch.tanh(combined[:, 5]))
combined[:, (6, 7)] = self._normalize_sin_cos(torch.tanh(combined[:, 6]), torch.tanh(combined[:, 7]))
output = self.head(combined)
return self.head(combined)
output = torch.tanh(output) # [-1; 1]
modified = output.clone()
modified[:, 2:6] = torch.mul(output[:, 2:6], torch.pi) # [-pi; pi]
return modified
def decode_output(self, output):
tx = output[:, 0]
ty = output[:, 1]
scale = output[:, 8]
angle1 = torch.atan2(output[:, 2], output[:, 3])
angle2 = torch.atan2(output[:, 4], output[:, 5])
angle3 = torch.atan2(output[:, 6], output[:, 7])
scale = output[:, 5]
angle1 = output[:, 2]
angle2 = output[:, 3]
angle3 = output[:, 4]
return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)
@@ -77,33 +77,27 @@ class HomographyLoss6(nn.Module):
self.trans_loss_weight = trans_loss_weight
self.scale_loss_weight = scale_loss_weight
@staticmethod
def dot_angles(src, dest):
sin_src = torch.sin(src)
cos_src = torch.cos(src)
sin_dest = torch.sin(dest)
cos_dest = torch.cos(dest)
return sin_src * sin_dest + cos_src * cos_dest
def forward(self, pred, target):
tx_loss = self.criterion(pred[:, 0], target[:, 0])
ty_loss = self.criterion(pred[:, 1], target[:, 1])
sin_rx_pred = pred[:, 2]
cos_rx_pred = pred[:, 3]
sin_ry_pred = pred[:, 4]
cos_ry_pred = pred[:, 5]
sin_rz_pred = pred[:, 6]
cos_rz_pred = pred[:, 7]
dot_rx = HomographyLoss6.dot_angles(pred[:, 2], target[:, 2])
dot_ry = HomographyLoss6.dot_angles(pred[:, 3], target[:, 3])
dot_rz = HomographyLoss6.dot_angles(pred[:, 4], target[:, 4])
sin_rx_target = torch.sin(target[:, 2])
cos_rx_target = torch.cos(target[:, 2])
sin_ry_target = torch.sin(target[:, 3])
cos_ry_target = torch.cos(target[:, 3])
sin_rz_target = torch.sin(target[:, 4])
cos_rz_target = torch.cos(target[:, 4])
rx_loss = self.criterion(dot_rx, torch.ones_like(dot_rx))
ry_loss = self.criterion(dot_ry, torch.ones_like(dot_ry))
rz_loss = self.criterion(dot_rz, torch.ones_like(dot_rz))
dot_rx = sin_rx_pred * sin_rx_target + cos_rx_pred * cos_rx_target
dot_ry = sin_ry_pred * sin_ry_target + cos_ry_pred * cos_ry_target
dot_rz = sin_rz_pred * sin_rz_target + cos_rz_pred * cos_rz_target
rx_loss = (1 - dot_rx).mean()
ry_loss = (1 - dot_ry).mean()
rz_loss = (1 - dot_rz).mean()
scale_loss = self.criterion(pred[:, 8], target[:, 5])
scale_loss = self.criterion(pred[:, 5], target[:, 5])
total_loss = (
self.trans_loss_weight * (tx_loss + ty_loss) +
@@ -113,14 +107,17 @@ class HomographyLoss6(nn.Module):
return total_loss
def compute_mse_components(self, pred, target):
decoded = self.decode_output(pred)
def compute_mse_components(self, decoded, target):
tx_mse = self.criterion(decoded[:, 0], target[:, 0]).item()
ty_mse = self.criterion(decoded[:, 1], target[:, 1]).item()
rx_mse = angular_difference(decoded[:, 2], target[:, 2]).item()
ry_mse = angular_difference(decoded[:, 3], target[:, 3]).item()
rz_mse = angular_difference(decoded[:, 4], target[:, 4]).item()
dot_rx = HomographyLoss6.dot_angles(decoded[:, 2], target[:, 2])
dot_ry = HomographyLoss6.dot_angles(decoded[:, 3], target[:, 3])
dot_rz = HomographyLoss6.dot_angles(decoded[:, 4], target[:, 4])
rx_mse = self.criterion(dot_rx, torch.ones_like(dot_rx)).item()
ry_mse = self.criterion(dot_ry, torch.ones_like(dot_ry)).item()
rz_mse = self.criterion(dot_rz, torch.ones_like(dot_rz)).item()
scale_mse = self.criterion(decoded[:, 5], target[:, 5]).item()

View File

@@ -0,0 +1,15 @@
from ..src.dataloader import *
train_loader, val_loader = create_data_loaders(config['data_dir'])
batch = next(iter(train_loader))
google_img = batch['google_img'][0]
yandex_img = batch['yandex_img'][0]
# google_img.permute((1, 2, 0)) * 255
batch['homography_params'].mean(axis=0)
print(batch['homography_matrix'][0])
print(batch['homography_params'][0])
K = get_camera_matrix(config['image_size'][0], config['image_size'][1])
print(homography_params_to_matrix(batch['homography_params'][0], K))
print(matrix_to_homography_params(batch['homography_matrix'][0].numpy(), K))