improve model
This commit is contained in:
@@ -43,6 +43,9 @@ Google/Yandex image pair loader with homography augmentation.
|
|||||||
- Batch dict with `google_img`, `yandex_img`, `homography_params`"""
|
- Batch dict with `google_img`, `yandex_img`, `homography_params`"""
|
||||||
# code: ./src/dataloader.py
|
# code: ./src/dataloader.py
|
||||||
|
|
||||||
|
# code: ./src/test_dataloader.py
|
||||||
|
|
||||||
|
|
||||||
# markdown
|
# markdown
|
||||||
"""## Model
|
"""## Model
|
||||||
|
|
||||||
|
|||||||
@@ -266,7 +266,7 @@
|
|||||||
" self.backbone = backbone\n",
|
" self.backbone = backbone\n",
|
||||||
"\n",
|
"\n",
|
||||||
" self.head = nn.Sequential(\n",
|
" self.head = nn.Sequential(\n",
|
||||||
" nn.Linear(self.feature_dim * 4, 512),\n",
|
" nn.Linear(self.feature_dim * 4, 1024),\n",
|
||||||
" nn.ReLU(inplace=True),\n",
|
" nn.ReLU(inplace=True),\n",
|
||||||
" nn.Dropout(dropout_rate),\n",
|
" nn.Dropout(dropout_rate),\n",
|
||||||
" nn.Linear(1024, 512),\n",
|
" nn.Linear(1024, 512),\n",
|
||||||
@@ -275,7 +275,7 @@
|
|||||||
" nn.Linear(512, 256),\n",
|
" nn.Linear(512, 256),\n",
|
||||||
" nn.ReLU(inplace=True),\n",
|
" nn.ReLU(inplace=True),\n",
|
||||||
" nn.Dropout(dropout_rate),\n",
|
" nn.Dropout(dropout_rate),\n",
|
||||||
" nn.Linear(512, 9),\n",
|
" nn.Linear(256, 6),\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def _normalize_sin_cos(self, _sin, _cos):\n",
|
" def _normalize_sin_cos(self, _sin, _cos):\n",
|
||||||
@@ -287,21 +287,21 @@
|
|||||||
" f2 = self.backbone(img2)\n",
|
" f2 = self.backbone(img2)\n",
|
||||||
" combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)\n",
|
" combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" combined[:, (0, 1)] = torch.tanh(combined[:, (0, 1)]) * 10 # [-10; 10]\n",
|
" output = self.head(combined)\n",
|
||||||
" combined[:, (2, 3)] = self._normalize_sin_cos(torch.tanh(combined[:, 2]), torch.tanh(combined[:, 3]))\n",
|
"\n",
|
||||||
" combined[:, (4, 5)] = self._normalize_sin_cos(torch.tanh(combined[:, 4]), torch.tanh(combined[:, 5]))\n",
|
" output = torch.tanh(output) # [-1; 1]\n",
|
||||||
" combined[:, (6, 7)] = self._normalize_sin_cos(torch.tanh(combined[:, 6]), torch.tanh(combined[:, 7]))\n",
|
" modified = output.clone()\n",
|
||||||
" \n",
|
" modified[:, 2:5] = torch.mul(output[:, 2:5], torch.pi) # [-pi; pi]\n",
|
||||||
" return self.head(combined)\n",
|
"\n",
|
||||||
|
" return modified\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def decode_output(self, output):\n",
|
" def decode_output(self, output):\n",
|
||||||
" tx = output[:, 0]\n",
|
" tx = output[:, 0]\n",
|
||||||
" ty = output[:, 1]\n",
|
" ty = output[:, 1]\n",
|
||||||
" scale = output[:, 8]\n",
|
" scale = output[:, 5]\n",
|
||||||
"\n",
|
" angle1 = output[:, 2]\n",
|
||||||
" angle1 = torch.atan2(output[:, 2], output[:, 3])\n",
|
" angle2 = output[:, 3]\n",
|
||||||
" angle2 = torch.atan2(output[:, 4], output[:, 5])\n",
|
" angle3 = output[:, 4]\n",
|
||||||
" angle3 = torch.atan2(output[:, 6], output[:, 7])\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)\n",
|
" return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -325,33 +325,27 @@
|
|||||||
" self.trans_loss_weight = trans_loss_weight\n",
|
" self.trans_loss_weight = trans_loss_weight\n",
|
||||||
" self.scale_loss_weight = scale_loss_weight\n",
|
" self.scale_loss_weight = scale_loss_weight\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" @staticmethod\n",
|
||||||
|
" def dot_angles(src, dest):\n",
|
||||||
|
" sin_src = torch.sin(src)\n",
|
||||||
|
" cos_src = torch.cos(src)\n",
|
||||||
|
" sin_dest = torch.sin(dest)\n",
|
||||||
|
" cos_dest = torch.cos(dest)\n",
|
||||||
|
" return sin_src * sin_dest + cos_src * cos_dest\n",
|
||||||
|
"\n",
|
||||||
" def forward(self, pred, target):\n",
|
" def forward(self, pred, target):\n",
|
||||||
" tx_loss = self.criterion(pred[:, 0], target[:, 0])\n",
|
" tx_loss = self.criterion(pred[:, 0], target[:, 0])\n",
|
||||||
" ty_loss = self.criterion(pred[:, 1], target[:, 1])\n",
|
" ty_loss = self.criterion(pred[:, 1], target[:, 1])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" sin_rx_pred = pred[:, 2]\n",
|
" dot_rx = HomographyLoss6.dot_angles(pred[:, 2], target[:, 2])\n",
|
||||||
" cos_rx_pred = pred[:, 3]\n",
|
" dot_ry = HomographyLoss6.dot_angles(pred[:, 3], target[:, 3])\n",
|
||||||
" sin_ry_pred = pred[:, 4]\n",
|
" dot_rz = HomographyLoss6.dot_angles(pred[:, 4], target[:, 4])\n",
|
||||||
" cos_ry_pred = pred[:, 5]\n",
|
|
||||||
" sin_rz_pred = pred[:, 6]\n",
|
|
||||||
" cos_rz_pred = pred[:, 7]\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" sin_rx_target = torch.sin(target[:, 2])\n",
|
" rx_loss = self.criterion(dot_rx, torch.ones_like(dot_rx))\n",
|
||||||
" cos_rx_target = torch.cos(target[:, 2])\n",
|
" ry_loss = self.criterion(dot_ry, torch.ones_like(dot_ry))\n",
|
||||||
" sin_ry_target = torch.sin(target[:, 3])\n",
|
" rz_loss = self.criterion(dot_rz, torch.ones_like(dot_rz))\n",
|
||||||
" cos_ry_target = torch.cos(target[:, 3])\n",
|
|
||||||
" sin_rz_target = torch.sin(target[:, 4])\n",
|
|
||||||
" cos_rz_target = torch.cos(target[:, 4])\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" dot_rx = sin_rx_pred * sin_rx_target + cos_rx_pred * cos_rx_target\n",
|
" scale_loss = self.criterion(pred[:, 5], target[:, 5])\n",
|
||||||
" dot_ry = sin_ry_pred * sin_ry_target + cos_ry_pred * cos_ry_target\n",
|
|
||||||
" dot_rz = sin_rz_pred * sin_rz_target + cos_rz_pred * cos_rz_target\n",
|
|
||||||
"\n",
|
|
||||||
" rx_loss = (1 - dot_rx).mean()\n",
|
|
||||||
" ry_loss = (1 - dot_ry).mean()\n",
|
|
||||||
" rz_loss = (1 - dot_rz).mean()\n",
|
|
||||||
"\n",
|
|
||||||
" scale_loss = self.criterion(pred[:, 8], target[:, 5])\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" total_loss = (\n",
|
" total_loss = (\n",
|
||||||
" self.trans_loss_weight * (tx_loss + ty_loss) +\n",
|
" self.trans_loss_weight * (tx_loss + ty_loss) +\n",
|
||||||
@@ -361,14 +355,17 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" return total_loss\n",
|
" return total_loss\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def compute_mse_components(self, pred, target):\n",
|
" def compute_mse_components(self, decoded, target):\n",
|
||||||
" decoded = self.decode_output(pred)\n",
|
|
||||||
" tx_mse = self.criterion(decoded[:, 0], target[:, 0]).item()\n",
|
" tx_mse = self.criterion(decoded[:, 0], target[:, 0]).item()\n",
|
||||||
" ty_mse = self.criterion(decoded[:, 1], target[:, 1]).item()\n",
|
" ty_mse = self.criterion(decoded[:, 1], target[:, 1]).item()\n",
|
||||||
|
" \n",
|
||||||
|
" dot_rx = HomographyLoss6.dot_angles(decoded[:, 2], target[:, 2])\n",
|
||||||
|
" dot_ry = HomographyLoss6.dot_angles(decoded[:, 3], target[:, 3])\n",
|
||||||
|
" dot_rz = HomographyLoss6.dot_angles(decoded[:, 4], target[:, 4])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" rx_mse = angular_difference(decoded[:, 2], target[:, 2]).item()\n",
|
" rx_mse = self.criterion(dot_rx, torch.ones_like(dot_rx)).item()\n",
|
||||||
" ry_mse = angular_difference(decoded[:, 3], target[:, 3]).item()\n",
|
" ry_mse = self.criterion(dot_ry, torch.ones_like(dot_ry)).item()\n",
|
||||||
" rz_mse = angular_difference(decoded[:, 4], target[:, 4]).item()\n",
|
" rz_mse = self.criterion(dot_rz, torch.ones_like(dot_rz)).item()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" scale_mse = self.criterion(decoded[:, 5], target[:, 5]).item()\n",
|
" scale_mse = self.criterion(decoded[:, 5], target[:, 5]).item()\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class HomographyCNN6(nn.Module):
|
|||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
|
|
||||||
self.head = nn.Sequential(
|
self.head = nn.Sequential(
|
||||||
nn.Linear(self.feature_dim * 4, 512),
|
nn.Linear(self.feature_dim * 4, 1024),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Dropout(dropout_rate),
|
nn.Dropout(dropout_rate),
|
||||||
nn.Linear(1024, 512),
|
nn.Linear(1024, 512),
|
||||||
@@ -27,7 +27,7 @@ class HomographyCNN6(nn.Module):
|
|||||||
nn.Linear(512, 256),
|
nn.Linear(512, 256),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Dropout(dropout_rate),
|
nn.Dropout(dropout_rate),
|
||||||
nn.Linear(512, 9),
|
nn.Linear(256, 6),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _normalize_sin_cos(self, _sin, _cos):
|
def _normalize_sin_cos(self, _sin, _cos):
|
||||||
@@ -39,21 +39,21 @@ class HomographyCNN6(nn.Module):
|
|||||||
f2 = self.backbone(img2)
|
f2 = self.backbone(img2)
|
||||||
combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)
|
combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)
|
||||||
|
|
||||||
combined[:, (0, 1)] = torch.tanh(combined[:, (0, 1)]) * 10 # [-10; 10]
|
output = self.head(combined)
|
||||||
combined[:, (2, 3)] = self._normalize_sin_cos(torch.tanh(combined[:, 2]), torch.tanh(combined[:, 3]))
|
|
||||||
combined[:, (4, 5)] = self._normalize_sin_cos(torch.tanh(combined[:, 4]), torch.tanh(combined[:, 5]))
|
output = torch.tanh(output) # [-1; 1]
|
||||||
combined[:, (6, 7)] = self._normalize_sin_cos(torch.tanh(combined[:, 6]), torch.tanh(combined[:, 7]))
|
modified = output.clone()
|
||||||
|
modified[:, 2:6] = torch.mul(output[:, 2:6], torch.pi) # [-pi; pi]
|
||||||
return self.head(combined)
|
|
||||||
|
return modified
|
||||||
|
|
||||||
def decode_output(self, output):
|
def decode_output(self, output):
|
||||||
tx = output[:, 0]
|
tx = output[:, 0]
|
||||||
ty = output[:, 1]
|
ty = output[:, 1]
|
||||||
scale = output[:, 8]
|
scale = output[:, 5]
|
||||||
|
angle1 = output[:, 2]
|
||||||
angle1 = torch.atan2(output[:, 2], output[:, 3])
|
angle2 = output[:, 3]
|
||||||
angle2 = torch.atan2(output[:, 4], output[:, 5])
|
angle3 = output[:, 4]
|
||||||
angle3 = torch.atan2(output[:, 6], output[:, 7])
|
|
||||||
|
|
||||||
return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)
|
return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)
|
||||||
|
|
||||||
@@ -77,33 +77,27 @@ class HomographyLoss6(nn.Module):
|
|||||||
self.trans_loss_weight = trans_loss_weight
|
self.trans_loss_weight = trans_loss_weight
|
||||||
self.scale_loss_weight = scale_loss_weight
|
self.scale_loss_weight = scale_loss_weight
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dot_angles(src, dest):
|
||||||
|
sin_src = torch.sin(src)
|
||||||
|
cos_src = torch.cos(src)
|
||||||
|
sin_dest = torch.sin(dest)
|
||||||
|
cos_dest = torch.cos(dest)
|
||||||
|
return sin_src * sin_dest + cos_src * cos_dest
|
||||||
|
|
||||||
def forward(self, pred, target):
|
def forward(self, pred, target):
|
||||||
tx_loss = self.criterion(pred[:, 0], target[:, 0])
|
tx_loss = self.criterion(pred[:, 0], target[:, 0])
|
||||||
ty_loss = self.criterion(pred[:, 1], target[:, 1])
|
ty_loss = self.criterion(pred[:, 1], target[:, 1])
|
||||||
|
|
||||||
sin_rx_pred = pred[:, 2]
|
dot_rx = HomographyLoss6.dot_angles(pred[:, 2], target[:, 2])
|
||||||
cos_rx_pred = pred[:, 3]
|
dot_ry = HomographyLoss6.dot_angles(pred[:, 3], target[:, 3])
|
||||||
sin_ry_pred = pred[:, 4]
|
dot_rz = HomographyLoss6.dot_angles(pred[:, 4], target[:, 4])
|
||||||
cos_ry_pred = pred[:, 5]
|
|
||||||
sin_rz_pred = pred[:, 6]
|
|
||||||
cos_rz_pred = pred[:, 7]
|
|
||||||
|
|
||||||
sin_rx_target = torch.sin(target[:, 2])
|
rx_loss = self.criterion(dot_rx, torch.ones_like(dot_rx))
|
||||||
cos_rx_target = torch.cos(target[:, 2])
|
ry_loss = self.criterion(dot_ry, torch.ones_like(dot_ry))
|
||||||
sin_ry_target = torch.sin(target[:, 3])
|
rz_loss = self.criterion(dot_rz, torch.ones_like(dot_rz))
|
||||||
cos_ry_target = torch.cos(target[:, 3])
|
|
||||||
sin_rz_target = torch.sin(target[:, 4])
|
|
||||||
cos_rz_target = torch.cos(target[:, 4])
|
|
||||||
|
|
||||||
dot_rx = sin_rx_pred * sin_rx_target + cos_rx_pred * cos_rx_target
|
scale_loss = self.criterion(pred[:, 5], target[:, 5])
|
||||||
dot_ry = sin_ry_pred * sin_ry_target + cos_ry_pred * cos_ry_target
|
|
||||||
dot_rz = sin_rz_pred * sin_rz_target + cos_rz_pred * cos_rz_target
|
|
||||||
|
|
||||||
rx_loss = (1 - dot_rx).mean()
|
|
||||||
ry_loss = (1 - dot_ry).mean()
|
|
||||||
rz_loss = (1 - dot_rz).mean()
|
|
||||||
|
|
||||||
scale_loss = self.criterion(pred[:, 8], target[:, 5])
|
|
||||||
|
|
||||||
total_loss = (
|
total_loss = (
|
||||||
self.trans_loss_weight * (tx_loss + ty_loss) +
|
self.trans_loss_weight * (tx_loss + ty_loss) +
|
||||||
@@ -113,14 +107,17 @@ class HomographyLoss6(nn.Module):
|
|||||||
|
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|
||||||
def compute_mse_components(self, pred, target):
|
def compute_mse_components(self, decoded, target):
|
||||||
decoded = self.decode_output(pred)
|
|
||||||
tx_mse = self.criterion(decoded[:, 0], target[:, 0]).item()
|
tx_mse = self.criterion(decoded[:, 0], target[:, 0]).item()
|
||||||
ty_mse = self.criterion(decoded[:, 1], target[:, 1]).item()
|
ty_mse = self.criterion(decoded[:, 1], target[:, 1]).item()
|
||||||
|
|
||||||
|
dot_rx = HomographyLoss6.dot_angles(decoded[:, 2], target[:, 2])
|
||||||
|
dot_ry = HomographyLoss6.dot_angles(decoded[:, 3], target[:, 3])
|
||||||
|
dot_rz = HomographyLoss6.dot_angles(decoded[:, 4], target[:, 4])
|
||||||
|
|
||||||
rx_mse = angular_difference(decoded[:, 2], target[:, 2]).item()
|
rx_mse = self.criterion(dot_rx, torch.ones_like(dot_rx)).item()
|
||||||
ry_mse = angular_difference(decoded[:, 3], target[:, 3]).item()
|
ry_mse = self.criterion(dot_ry, torch.ones_like(dot_ry)).item()
|
||||||
rz_mse = angular_difference(decoded[:, 4], target[:, 4]).item()
|
rz_mse = self.criterion(dot_rz, torch.ones_like(dot_rz)).item()
|
||||||
|
|
||||||
scale_mse = self.criterion(decoded[:, 5], target[:, 5]).item()
|
scale_mse = self.criterion(decoded[:, 5], target[:, 5]).item()
|
||||||
|
|
||||||
|
|||||||
15
models/SiaN/src/test_dataloader.py
Normal file
15
models/SiaN/src/test_dataloader.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from ..src.dataloader import *
|
||||||
|
|
||||||
|
train_loader, val_loader = create_data_loaders(config['data_dir'])
|
||||||
|
batch = next(iter(train_loader))
|
||||||
|
google_img = batch['google_img'][0]
|
||||||
|
yandex_img = batch['yandex_img'][0]
|
||||||
|
|
||||||
|
# google_img.permute((1, 2, 0)) * 255
|
||||||
|
batch['homography_params'].mean(axis=0)
|
||||||
|
|
||||||
|
print(batch['homography_matrix'][0])
|
||||||
|
print(batch['homography_params'][0])
|
||||||
|
K = get_camera_matrix(config['image_size'][0], config['image_size'][1])
|
||||||
|
print(homography_params_to_matrix(batch['homography_params'][0], K))
|
||||||
|
print(matrix_to_homography_params(batch['homography_matrix'][0].numpy(), K))
|
||||||
Reference in New Issue
Block a user