diff --git a/models/SiaN/_schema.py b/models/SiaN/_schema.py index 0fc4462..c962fbc 100644 --- a/models/SiaN/_schema.py +++ b/models/SiaN/_schema.py @@ -43,6 +43,9 @@ Google/Yandex image pair loader with homography augmentation. - Batch dict with `google_img`, `yandex_img`, `homography_params`""" # code: ./src/dataloader.py +# code: ./src/test_dataloader.py + + # markdown """## Model diff --git a/models/SiaN/notebook.gen.ipynb b/models/SiaN/notebook.gen.ipynb index 6c2d035..8f796cc 100644 --- a/models/SiaN/notebook.gen.ipynb +++ b/models/SiaN/notebook.gen.ipynb @@ -266,7 +266,7 @@ " self.backbone = backbone\n", "\n", " self.head = nn.Sequential(\n", - " nn.Linear(self.feature_dim * 4, 512),\n", + " nn.Linear(self.feature_dim * 4, 1024),\n", " nn.ReLU(inplace=True),\n", " nn.Dropout(dropout_rate),\n", " nn.Linear(1024, 512),\n", @@ -275,7 +275,7 @@ " nn.Linear(512, 256),\n", " nn.ReLU(inplace=True),\n", " nn.Dropout(dropout_rate),\n", - " nn.Linear(512, 9),\n", + " nn.Linear(256, 6),\n", " )\n", "\n", " def _normalize_sin_cos(self, _sin, _cos):\n", @@ -287,21 +287,21 @@ " f2 = self.backbone(img2)\n", " combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)\n", "\n", - " combined[:, (0, 1)] = torch.tanh(combined[:, (0, 1)]) * 10 # [-10; 10]\n", - " combined[:, (2, 3)] = self._normalize_sin_cos(torch.tanh(combined[:, 2]), torch.tanh(combined[:, 3]))\n", - " combined[:, (4, 5)] = self._normalize_sin_cos(torch.tanh(combined[:, 4]), torch.tanh(combined[:, 5]))\n", - " combined[:, (6, 7)] = self._normalize_sin_cos(torch.tanh(combined[:, 6]), torch.tanh(combined[:, 7]))\n", - " \n", - " return self.head(combined)\n", + " output = self.head(combined)\n", + "\n", + " output = torch.tanh(output) # [-1; 1]\n", + " modified = output.clone()\n", + " modified[:, 2:5] = torch.mul(output[:, 2:5], torch.pi) # [-pi; pi]\n", + "\n", + " return modified\n", "\n", " def decode_output(self, output):\n", " tx = output[:, 0]\n", " ty = output[:, 1]\n", - " scale = output[:, 
8]\n", - "\n", - " angle1 = torch.atan2(output[:, 2], output[:, 3])\n", - " angle2 = torch.atan2(output[:, 4], output[:, 5])\n", - " angle3 = torch.atan2(output[:, 6], output[:, 7])\n", + " scale = output[:, 5]\n", + " angle1 = output[:, 2]\n", + " angle2 = output[:, 3]\n", + " angle3 = output[:, 4]\n", "\n", " return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)\n", "\n", @@ -325,33 +325,27 @@ " self.trans_loss_weight = trans_loss_weight\n", " self.scale_loss_weight = scale_loss_weight\n", "\n", + " @staticmethod\n", + " def dot_angles(src, dest):\n", + " sin_src = torch.sin(src)\n", + " cos_src = torch.cos(src)\n", + " sin_dest = torch.sin(dest)\n", + " cos_dest = torch.cos(dest)\n", + " return sin_src * sin_dest + cos_src * cos_dest\n", + "\n", " def forward(self, pred, target):\n", " tx_loss = self.criterion(pred[:, 0], target[:, 0])\n", " ty_loss = self.criterion(pred[:, 1], target[:, 1])\n", "\n", - " sin_rx_pred = pred[:, 2]\n", - " cos_rx_pred = pred[:, 3]\n", - " sin_ry_pred = pred[:, 4]\n", - " cos_ry_pred = pred[:, 5]\n", - " sin_rz_pred = pred[:, 6]\n", - " cos_rz_pred = pred[:, 7]\n", + " dot_rx = HomographyLoss6.dot_angles(pred[:, 2], target[:, 2])\n", + " dot_ry = HomographyLoss6.dot_angles(pred[:, 3], target[:, 3])\n", + " dot_rz = HomographyLoss6.dot_angles(pred[:, 4], target[:, 4])\n", "\n", - " sin_rx_target = torch.sin(target[:, 2])\n", - " cos_rx_target = torch.cos(target[:, 2])\n", - " sin_ry_target = torch.sin(target[:, 3])\n", - " cos_ry_target = torch.cos(target[:, 3])\n", - " sin_rz_target = torch.sin(target[:, 4])\n", - " cos_rz_target = torch.cos(target[:, 4])\n", + " rx_loss = self.criterion(dot_rx, torch.ones_like(dot_rx))\n", + " ry_loss = self.criterion(dot_ry, torch.ones_like(dot_ry))\n", + " rz_loss = self.criterion(dot_rz, torch.ones_like(dot_rz))\n", "\n", - " dot_rx = sin_rx_pred * sin_rx_target + cos_rx_pred * cos_rx_target\n", - " dot_ry = sin_ry_pred * sin_ry_target + cos_ry_pred * cos_ry_target\n", - " dot_rz = 
sin_rz_pred * sin_rz_target + cos_rz_pred * cos_rz_target\n", - "\n", - " rx_loss = (1 - dot_rx).mean()\n", - " ry_loss = (1 - dot_ry).mean()\n", - " rz_loss = (1 - dot_rz).mean()\n", - "\n", - " scale_loss = self.criterion(pred[:, 8], target[:, 5])\n", + " scale_loss = self.criterion(pred[:, 5], target[:, 5])\n", "\n", " total_loss = (\n", " self.trans_loss_weight * (tx_loss + ty_loss) +\n", @@ -361,14 +355,17 @@ "\n", " return total_loss\n", "\n", - " def compute_mse_components(self, pred, target):\n", - " decoded = self.decode_output(pred)\n", + " def compute_mse_components(self, decoded, target):\n", " tx_mse = self.criterion(decoded[:, 0], target[:, 0]).item()\n", " ty_mse = self.criterion(decoded[:, 1], target[:, 1]).item()\n", + " \n", + " dot_rx = HomographyLoss6.dot_angles(decoded[:, 2], target[:, 2])\n", + " dot_ry = HomographyLoss6.dot_angles(decoded[:, 3], target[:, 3])\n", + " dot_rz = HomographyLoss6.dot_angles(decoded[:, 4], target[:, 4])\n", "\n", - " rx_mse = angular_difference(decoded[:, 2], target[:, 2]).item()\n", - " ry_mse = angular_difference(decoded[:, 3], target[:, 3]).item()\n", - " rz_mse = angular_difference(decoded[:, 4], target[:, 4]).item()\n", + " rx_mse = self.criterion(dot_rx, torch.ones_like(dot_rx)).item()\n", + " ry_mse = self.criterion(dot_ry, torch.ones_like(dot_ry)).item()\n", + " rz_mse = self.criterion(dot_rz, torch.ones_like(dot_rz)).item()\n", "\n", " scale_mse = self.criterion(decoded[:, 5], target[:, 5]).item()\n", "\n", diff --git a/models/SiaN/src/model.py b/models/SiaN/src/model.py index 7c81726..b175d54 100644 --- a/models/SiaN/src/model.py +++ b/models/SiaN/src/model.py @@ -18,7 +18,7 @@ class HomographyCNN6(nn.Module): self.backbone = backbone self.head = nn.Sequential( - nn.Linear(self.feature_dim * 4, 512), + nn.Linear(self.feature_dim * 4, 1024), nn.ReLU(inplace=True), nn.Dropout(dropout_rate), nn.Linear(1024, 512), @@ -27,7 +27,7 @@ class HomographyCNN6(nn.Module): nn.Linear(512, 256), nn.ReLU(inplace=True), 
nn.Dropout(dropout_rate), - nn.Linear(512, 9), + nn.Linear(256, 6), ) def _normalize_sin_cos(self, _sin, _cos): @@ -39,21 +39,21 @@ class HomographyCNN6(nn.Module): f2 = self.backbone(img2) combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1) - combined[:, (0, 1)] = torch.tanh(combined[:, (0, 1)]) * 10 # [-10; 10] - combined[:, (2, 3)] = self._normalize_sin_cos(torch.tanh(combined[:, 2]), torch.tanh(combined[:, 3])) - combined[:, (4, 5)] = self._normalize_sin_cos(torch.tanh(combined[:, 4]), torch.tanh(combined[:, 5])) - combined[:, (6, 7)] = self._normalize_sin_cos(torch.tanh(combined[:, 6]), torch.tanh(combined[:, 7])) - - return self.head(combined) + output = self.head(combined) + + output = torch.tanh(output) # [-1; 1] + modified = output.clone() + modified[:, 2:6] = torch.mul(output[:, 2:6], torch.pi) # [-pi; pi] NOTE(review): the 2:6 slice also multiplies column 5 by pi, and decode_output reads column 5 as scale; the notebook.gen.ipynb hunk of this same change uses 2:5, confirm which slice is intended + + return modified def decode_output(self, output): tx = output[:, 0] ty = output[:, 1] - scale = output[:, 8] - - angle1 = torch.atan2(output[:, 2], output[:, 3]) - angle2 = torch.atan2(output[:, 4], output[:, 5]) - angle3 = torch.atan2(output[:, 6], output[:, 7]) + scale = output[:, 5] + angle1 = output[:, 2] + angle2 = output[:, 3] + angle3 = output[:, 4] return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1) @@ -77,33 +77,27 @@ class HomographyLoss6(nn.Module): self.trans_loss_weight = trans_loss_weight self.scale_loss_weight = scale_loss_weight + @staticmethod + def dot_angles(src, dest): + sin_src = torch.sin(src) + cos_src = torch.cos(src) + sin_dest = torch.sin(dest) + cos_dest = torch.cos(dest) + return sin_src * sin_dest + cos_src * cos_dest # cosine of the angle difference, equal to 1 when the two angles coincide + def forward(self, pred, target): tx_loss = self.criterion(pred[:, 0], target[:, 0]) ty_loss = self.criterion(pred[:, 1], target[:, 1]) - sin_rx_pred = pred[:, 2] - cos_rx_pred = pred[:, 3] - sin_ry_pred = pred[:, 4] - cos_ry_pred = pred[:, 5] - sin_rz_pred = pred[:, 6] - cos_rz_pred = pred[:, 7] + dot_rx = HomographyLoss6.dot_angles(pred[:, 2], target[:, 2]) + dot_ry = 
HomographyLoss6.dot_angles(pred[:, 3], target[:, 3]) + dot_rz = HomographyLoss6.dot_angles(pred[:, 4], target[:, 4]) - sin_rx_target = torch.sin(target[:, 2]) - cos_rx_target = torch.cos(target[:, 2]) - sin_ry_target = torch.sin(target[:, 3]) - cos_ry_target = torch.cos(target[:, 3]) - sin_rz_target = torch.sin(target[:, 4]) - cos_rz_target = torch.cos(target[:, 4]) + rx_loss = self.criterion(dot_rx, torch.ones_like(dot_rx)) + ry_loss = self.criterion(dot_ry, torch.ones_like(dot_ry)) + rz_loss = self.criterion(dot_rz, torch.ones_like(dot_rz)) - dot_rx = sin_rx_pred * sin_rx_target + cos_rx_pred * cos_rx_target - dot_ry = sin_ry_pred * sin_ry_target + cos_ry_pred * cos_ry_target - dot_rz = sin_rz_pred * sin_rz_target + cos_rz_pred * cos_rz_target - - rx_loss = (1 - dot_rx).mean() - ry_loss = (1 - dot_ry).mean() - rz_loss = (1 - dot_rz).mean() - - scale_loss = self.criterion(pred[:, 8], target[:, 5]) + scale_loss = self.criterion(pred[:, 5], target[:, 5]) total_loss = ( self.trans_loss_weight * (tx_loss + ty_loss) + @@ -113,14 +107,17 @@ class HomographyLoss6(nn.Module): return total_loss - def compute_mse_components(self, pred, target): - decoded = self.decode_output(pred) + def compute_mse_components(self, decoded, target): tx_mse = self.criterion(decoded[:, 0], target[:, 0]).item() ty_mse = self.criterion(decoded[:, 1], target[:, 1]).item() + + dot_rx = HomographyLoss6.dot_angles(decoded[:, 2], target[:, 2]) + dot_ry = HomographyLoss6.dot_angles(decoded[:, 3], target[:, 3]) + dot_rz = HomographyLoss6.dot_angles(decoded[:, 4], target[:, 4]) - rx_mse = angular_difference(decoded[:, 2], target[:, 2]).item() - ry_mse = angular_difference(decoded[:, 3], target[:, 3]).item() - rz_mse = angular_difference(decoded[:, 4], target[:, 4]).item() + rx_mse = self.criterion(dot_rx, torch.ones_like(dot_rx)).item() + ry_mse = self.criterion(dot_ry, torch.ones_like(dot_ry)).item() + rz_mse = self.criterion(dot_rz, torch.ones_like(dot_rz)).item() scale_mse = 
self.criterion(decoded[:, 5], target[:, 5]).item() diff --git a/models/SiaN/src/test_dataloader.py b/models/SiaN/src/test_dataloader.py new file mode 100644 index 0000000..7e7d685 --- /dev/null +++ b/models/SiaN/src/test_dataloader.py @@ -0,0 +1,15 @@ +from ..src.dataloader import * # NOTE(review): config, create_data_loaders, get_camera_matrix and the params/matrix converters are presumably all supplied by this star import; confirm + +train_loader, val_loader = create_data_loaders(config['data_dir']) +batch = next(iter(train_loader)) +google_img = batch['google_img'][0] +yandex_img = batch['yandex_img'][0] + +# google_img.permute((1, 2, 0)) * 255 +batch['homography_params'].mean(axis=0) # NOTE(review): bare expression, the mean is computed but never stored or printed + +print(batch['homography_matrix'][0]) # NOTE(review): the loader description in _schema.py lists only google_img, yandex_img and homography_params as batch keys; confirm homography_matrix is provided +print(batch['homography_params'][0]) +K = get_camera_matrix(config['image_size'][0], config['image_size'][1]) +print(homography_params_to_matrix(batch['homography_params'][0], K)) +print(matrix_to_homography_params(batch['homography_matrix'][0].numpy(), K)) # NOTE(review): presumably these last two prints are a round-trip check against the matrix and params printed above; confirm