improve model
This commit is contained in:
@@ -43,6 +43,9 @@ Google/Yandex image pair loader with homography augmentation.
|
||||
- Batch dict with `google_img`, `yandex_img`, `homography_params`"""
|
||||
# code: ./src/dataloader.py
|
||||
|
||||
# code: ./src/test_dataloader.py
|
||||
|
||||
|
||||
# markdown
|
||||
"""## Model
|
||||
|
||||
|
||||
@@ -266,7 +266,7 @@
|
||||
" self.backbone = backbone\n",
|
||||
"\n",
|
||||
" self.head = nn.Sequential(\n",
|
||||
" nn.Linear(self.feature_dim * 4, 512),\n",
|
||||
" nn.Linear(self.feature_dim * 4, 1024),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Dropout(dropout_rate),\n",
|
||||
" nn.Linear(1024, 512),\n",
|
||||
@@ -275,7 +275,7 @@
|
||||
" nn.Linear(512, 256),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Dropout(dropout_rate),\n",
|
||||
" nn.Linear(512, 9),\n",
|
||||
" nn.Linear(256, 6),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def _normalize_sin_cos(self, _sin, _cos):\n",
|
||||
@@ -287,21 +287,21 @@
|
||||
" f2 = self.backbone(img2)\n",
|
||||
" combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)\n",
|
||||
"\n",
|
||||
" combined[:, (0, 1)] = torch.tanh(combined[:, (0, 1)]) * 10 # [-10; 10]\n",
|
||||
" combined[:, (2, 3)] = self._normalize_sin_cos(torch.tanh(combined[:, 2]), torch.tanh(combined[:, 3]))\n",
|
||||
" combined[:, (4, 5)] = self._normalize_sin_cos(torch.tanh(combined[:, 4]), torch.tanh(combined[:, 5]))\n",
|
||||
" combined[:, (6, 7)] = self._normalize_sin_cos(torch.tanh(combined[:, 6]), torch.tanh(combined[:, 7]))\n",
|
||||
" output = self.head(combined)\n",
|
||||
"\n",
|
||||
" return self.head(combined)\n",
|
||||
" output = torch.tanh(output) # [-1; 1]\n",
|
||||
" modified = output.clone()\n",
|
||||
" modified[:, 2:5] = torch.mul(output[:, 2:5], torch.pi) # [-pi; pi]\n",
|
||||
"\n",
|
||||
" return modified\n",
|
||||
"\n",
|
||||
" def decode_output(self, output):\n",
|
||||
" tx = output[:, 0]\n",
|
||||
" ty = output[:, 1]\n",
|
||||
" scale = output[:, 8]\n",
|
||||
"\n",
|
||||
" angle1 = torch.atan2(output[:, 2], output[:, 3])\n",
|
||||
" angle2 = torch.atan2(output[:, 4], output[:, 5])\n",
|
||||
" angle3 = torch.atan2(output[:, 6], output[:, 7])\n",
|
||||
" scale = output[:, 5]\n",
|
||||
" angle1 = output[:, 2]\n",
|
||||
" angle2 = output[:, 3]\n",
|
||||
" angle3 = output[:, 4]\n",
|
||||
"\n",
|
||||
" return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)\n",
|
||||
"\n",
|
||||
@@ -325,33 +325,27 @@
|
||||
" self.trans_loss_weight = trans_loss_weight\n",
|
||||
" self.scale_loss_weight = scale_loss_weight\n",
|
||||
"\n",
|
||||
" @staticmethod\n",
|
||||
" def dot_angles(src, dest):\n",
|
||||
" sin_src = torch.sin(src)\n",
|
||||
" cos_src = torch.cos(src)\n",
|
||||
" sin_dest = torch.sin(dest)\n",
|
||||
" cos_dest = torch.cos(dest)\n",
|
||||
" return sin_src * sin_dest + cos_src * cos_dest\n",
|
||||
"\n",
|
||||
" def forward(self, pred, target):\n",
|
||||
" tx_loss = self.criterion(pred[:, 0], target[:, 0])\n",
|
||||
" ty_loss = self.criterion(pred[:, 1], target[:, 1])\n",
|
||||
"\n",
|
||||
" sin_rx_pred = pred[:, 2]\n",
|
||||
" cos_rx_pred = pred[:, 3]\n",
|
||||
" sin_ry_pred = pred[:, 4]\n",
|
||||
" cos_ry_pred = pred[:, 5]\n",
|
||||
" sin_rz_pred = pred[:, 6]\n",
|
||||
" cos_rz_pred = pred[:, 7]\n",
|
||||
" dot_rx = HomographyLoss6.dot_angles(pred[:, 2], target[:, 2])\n",
|
||||
" dot_ry = HomographyLoss6.dot_angles(pred[:, 3], target[:, 3])\n",
|
||||
" dot_rz = HomographyLoss6.dot_angles(pred[:, 4], target[:, 4])\n",
|
||||
"\n",
|
||||
" sin_rx_target = torch.sin(target[:, 2])\n",
|
||||
" cos_rx_target = torch.cos(target[:, 2])\n",
|
||||
" sin_ry_target = torch.sin(target[:, 3])\n",
|
||||
" cos_ry_target = torch.cos(target[:, 3])\n",
|
||||
" sin_rz_target = torch.sin(target[:, 4])\n",
|
||||
" cos_rz_target = torch.cos(target[:, 4])\n",
|
||||
" rx_loss = self.criterion(dot_rx, torch.ones_like(dot_rx))\n",
|
||||
" ry_loss = self.criterion(dot_ry, torch.ones_like(dot_ry))\n",
|
||||
" rz_loss = self.criterion(dot_rz, torch.ones_like(dot_rz))\n",
|
||||
"\n",
|
||||
" dot_rx = sin_rx_pred * sin_rx_target + cos_rx_pred * cos_rx_target\n",
|
||||
" dot_ry = sin_ry_pred * sin_ry_target + cos_ry_pred * cos_ry_target\n",
|
||||
" dot_rz = sin_rz_pred * sin_rz_target + cos_rz_pred * cos_rz_target\n",
|
||||
"\n",
|
||||
" rx_loss = (1 - dot_rx).mean()\n",
|
||||
" ry_loss = (1 - dot_ry).mean()\n",
|
||||
" rz_loss = (1 - dot_rz).mean()\n",
|
||||
"\n",
|
||||
" scale_loss = self.criterion(pred[:, 8], target[:, 5])\n",
|
||||
" scale_loss = self.criterion(pred[:, 5], target[:, 5])\n",
|
||||
"\n",
|
||||
" total_loss = (\n",
|
||||
" self.trans_loss_weight * (tx_loss + ty_loss) +\n",
|
||||
@@ -361,14 +355,17 @@
|
||||
"\n",
|
||||
" return total_loss\n",
|
||||
"\n",
|
||||
" def compute_mse_components(self, pred, target):\n",
|
||||
" decoded = self.decode_output(pred)\n",
|
||||
" def compute_mse_components(self, decoded, target):\n",
|
||||
" tx_mse = self.criterion(decoded[:, 0], target[:, 0]).item()\n",
|
||||
" ty_mse = self.criterion(decoded[:, 1], target[:, 1]).item()\n",
|
||||
" \n",
|
||||
" rx_mse = angular_difference(decoded[:, 2], target[:, 2]).item()\n",
|
||||
" ry_mse = angular_difference(decoded[:, 3], target[:, 3]).item()\n",
|
||||
" rz_mse = angular_difference(decoded[:, 4], target[:, 4]).item()\n",
|
||||
" dot_rx = HomographyLoss6.dot_angles(decoded[:, 2], target[:, 2])\n",
|
||||
" dot_ry = HomographyLoss6.dot_angles(decoded[:, 3], target[:, 3])\n",
|
||||
" dot_rz = HomographyLoss6.dot_angles(decoded[:, 4], target[:, 4])\n",
|
||||
"\n",
|
||||
" rx_mse = self.criterion(dot_rx, torch.ones_like(dot_rx)).item()\n",
|
||||
" ry_mse = self.criterion(dot_ry, torch.ones_like(dot_ry)).item()\n",
|
||||
" rz_mse = self.criterion(dot_rz, torch.ones_like(dot_rz)).item()\n",
|
||||
"\n",
|
||||
" scale_mse = self.criterion(decoded[:, 5], target[:, 5]).item()\n",
|
||||
"\n",
|
||||
|
||||
@@ -18,7 +18,7 @@ class HomographyCNN6(nn.Module):
|
||||
self.backbone = backbone
|
||||
|
||||
self.head = nn.Sequential(
|
||||
nn.Linear(self.feature_dim * 4, 512),
|
||||
nn.Linear(self.feature_dim * 4, 1024),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(dropout_rate),
|
||||
nn.Linear(1024, 512),
|
||||
@@ -27,7 +27,7 @@ class HomographyCNN6(nn.Module):
|
||||
nn.Linear(512, 256),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(dropout_rate),
|
||||
nn.Linear(512, 9),
|
||||
nn.Linear(256, 6),
|
||||
)
|
||||
|
||||
def _normalize_sin_cos(self, _sin, _cos):
|
||||
@@ -39,21 +39,21 @@ class HomographyCNN6(nn.Module):
|
||||
f2 = self.backbone(img2)
|
||||
combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)
|
||||
|
||||
combined[:, (0, 1)] = torch.tanh(combined[:, (0, 1)]) * 10 # [-10; 10]
|
||||
combined[:, (2, 3)] = self._normalize_sin_cos(torch.tanh(combined[:, 2]), torch.tanh(combined[:, 3]))
|
||||
combined[:, (4, 5)] = self._normalize_sin_cos(torch.tanh(combined[:, 4]), torch.tanh(combined[:, 5]))
|
||||
combined[:, (6, 7)] = self._normalize_sin_cos(torch.tanh(combined[:, 6]), torch.tanh(combined[:, 7]))
|
||||
output = self.head(combined)
|
||||
|
||||
return self.head(combined)
|
||||
output = torch.tanh(output) # [-1; 1]
|
||||
modified = output.clone()
|
||||
modified[:, 2:6] = torch.mul(output[:, 2:6], torch.pi) # [-pi; pi]
|
||||
|
||||
return modified
|
||||
|
||||
def decode_output(self, output):
|
||||
tx = output[:, 0]
|
||||
ty = output[:, 1]
|
||||
scale = output[:, 8]
|
||||
|
||||
angle1 = torch.atan2(output[:, 2], output[:, 3])
|
||||
angle2 = torch.atan2(output[:, 4], output[:, 5])
|
||||
angle3 = torch.atan2(output[:, 6], output[:, 7])
|
||||
scale = output[:, 5]
|
||||
angle1 = output[:, 2]
|
||||
angle2 = output[:, 3]
|
||||
angle3 = output[:, 4]
|
||||
|
||||
return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)
|
||||
|
||||
@@ -77,33 +77,27 @@ class HomographyLoss6(nn.Module):
|
||||
self.trans_loss_weight = trans_loss_weight
|
||||
self.scale_loss_weight = scale_loss_weight
|
||||
|
||||
@staticmethod
|
||||
def dot_angles(src, dest):
|
||||
sin_src = torch.sin(src)
|
||||
cos_src = torch.cos(src)
|
||||
sin_dest = torch.sin(dest)
|
||||
cos_dest = torch.cos(dest)
|
||||
return sin_src * sin_dest + cos_src * cos_dest
|
||||
|
||||
def forward(self, pred, target):
|
||||
tx_loss = self.criterion(pred[:, 0], target[:, 0])
|
||||
ty_loss = self.criterion(pred[:, 1], target[:, 1])
|
||||
|
||||
sin_rx_pred = pred[:, 2]
|
||||
cos_rx_pred = pred[:, 3]
|
||||
sin_ry_pred = pred[:, 4]
|
||||
cos_ry_pred = pred[:, 5]
|
||||
sin_rz_pred = pred[:, 6]
|
||||
cos_rz_pred = pred[:, 7]
|
||||
dot_rx = HomographyLoss6.dot_angles(pred[:, 2], target[:, 2])
|
||||
dot_ry = HomographyLoss6.dot_angles(pred[:, 3], target[:, 3])
|
||||
dot_rz = HomographyLoss6.dot_angles(pred[:, 4], target[:, 4])
|
||||
|
||||
sin_rx_target = torch.sin(target[:, 2])
|
||||
cos_rx_target = torch.cos(target[:, 2])
|
||||
sin_ry_target = torch.sin(target[:, 3])
|
||||
cos_ry_target = torch.cos(target[:, 3])
|
||||
sin_rz_target = torch.sin(target[:, 4])
|
||||
cos_rz_target = torch.cos(target[:, 4])
|
||||
rx_loss = self.criterion(dot_rx, torch.ones_like(dot_rx))
|
||||
ry_loss = self.criterion(dot_ry, torch.ones_like(dot_ry))
|
||||
rz_loss = self.criterion(dot_rz, torch.ones_like(dot_rz))
|
||||
|
||||
dot_rx = sin_rx_pred * sin_rx_target + cos_rx_pred * cos_rx_target
|
||||
dot_ry = sin_ry_pred * sin_ry_target + cos_ry_pred * cos_ry_target
|
||||
dot_rz = sin_rz_pred * sin_rz_target + cos_rz_pred * cos_rz_target
|
||||
|
||||
rx_loss = (1 - dot_rx).mean()
|
||||
ry_loss = (1 - dot_ry).mean()
|
||||
rz_loss = (1 - dot_rz).mean()
|
||||
|
||||
scale_loss = self.criterion(pred[:, 8], target[:, 5])
|
||||
scale_loss = self.criterion(pred[:, 5], target[:, 5])
|
||||
|
||||
total_loss = (
|
||||
self.trans_loss_weight * (tx_loss + ty_loss) +
|
||||
@@ -113,14 +107,17 @@ class HomographyLoss6(nn.Module):
|
||||
|
||||
return total_loss
|
||||
|
||||
def compute_mse_components(self, pred, target):
|
||||
decoded = self.decode_output(pred)
|
||||
def compute_mse_components(self, decoded, target):
|
||||
tx_mse = self.criterion(decoded[:, 0], target[:, 0]).item()
|
||||
ty_mse = self.criterion(decoded[:, 1], target[:, 1]).item()
|
||||
|
||||
rx_mse = angular_difference(decoded[:, 2], target[:, 2]).item()
|
||||
ry_mse = angular_difference(decoded[:, 3], target[:, 3]).item()
|
||||
rz_mse = angular_difference(decoded[:, 4], target[:, 4]).item()
|
||||
dot_rx = HomographyLoss6.dot_angles(decoded[:, 2], target[:, 2])
|
||||
dot_ry = HomographyLoss6.dot_angles(decoded[:, 3], target[:, 3])
|
||||
dot_rz = HomographyLoss6.dot_angles(decoded[:, 4], target[:, 4])
|
||||
|
||||
rx_mse = self.criterion(dot_rx, torch.ones_like(dot_rx)).item()
|
||||
ry_mse = self.criterion(dot_ry, torch.ones_like(dot_ry)).item()
|
||||
rz_mse = self.criterion(dot_rz, torch.ones_like(dot_rz)).item()
|
||||
|
||||
scale_mse = self.criterion(decoded[:, 5], target[:, 5]).item()
|
||||
|
||||
|
||||
15
models/SiaN/src/test_dataloader.py
Normal file
15
models/SiaN/src/test_dataloader.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from ..src.dataloader import *
|
||||
|
||||
train_loader, val_loader = create_data_loaders(config['data_dir'])
|
||||
batch = next(iter(train_loader))
|
||||
google_img = batch['google_img'][0]
|
||||
yandex_img = batch['yandex_img'][0]
|
||||
|
||||
# google_img.permute((1, 2, 0)) * 255
|
||||
batch['homography_params'].mean(axis=0)
|
||||
|
||||
print(batch['homography_matrix'][0])
|
||||
print(batch['homography_params'][0])
|
||||
K = get_camera_matrix(config['image_size'][0], config['image_size'][1])
|
||||
print(homography_params_to_matrix(batch['homography_params'][0], K))
|
||||
print(matrix_to_homography_params(batch['homography_matrix'][0].numpy(), K))
|
||||
Reference in New Issue
Block a user