feat(SiaN): change output angles to sin/cos values

This commit is contained in:
2026-04-05 13:10:54 +03:00
parent fa4c4b83ae
commit daee1767fb
4 changed files with 947 additions and 868 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -9,6 +9,12 @@ IMG_DIR = os.path.join(config["output_dir"], "images")
os.makedirs(IMG_DIR, exist_ok=True)
def angular_difference(pred_angles, target_angles):
    """Return the absolute wrapped difference between two angle tensors.

    The raw difference is passed through ``atan2(sin, cos)`` so that it is
    folded into ``[-pi, pi]`` before the absolute value is taken; angles that
    differ by a whole turn therefore compare as (numerically) equal.

    Args:
        pred_angles: Tensor of predicted angles, in radians.
        target_angles: Tensor of reference angles, in radians; must be
            broadcastable against ``pred_angles``.

    Returns:
        Tensor of non-negative angular errors in ``[0, pi]``.
    """
    delta = pred_angles - target_angles
    # atan2 of the (sin, cos) pair wraps the difference onto the principal branch.
    wrapped = torch.atan2(delta.sin(), delta.cos())
    return wrapped.abs()
def analyze_training(trainer):
print("=== Training Analysis ===\n")
@@ -26,7 +32,7 @@ def analyze_training(trainer):
trainer.model.eval()
n_samples = 50
names = ["rx", "ry", "rz", "tx", "ty", "scale"]
names = ["tx", "ty", "rx", "ry", "rz", "scale"]
with torch.no_grad():
all_errors = [[] for _ in range(6)]
@@ -42,11 +48,23 @@ def analyze_training(trainer):
yandex_img = batch["yandex_img"].to(trainer.device)
target_params = batch["homography_params"].to(trainer.device)
pred_params = trainer.model(google_img, yandex_img)
decoded_pred = trainer.model.decode_output(pred_params)
tx_error = torch.abs(decoded_pred[:, 0] - target_params[:, 0]).item()
ty_error = torch.abs(decoded_pred[:, 1] - target_params[:, 1]).item()
rx_error = angular_difference(decoded_pred[:, 2], target_params[:, 2]).item()
ry_error = angular_difference(decoded_pred[:, 3], target_params[:, 3]).item()
rz_error = angular_difference(decoded_pred[:, 4], target_params[:, 4]).item()
scale_error = torch.abs(decoded_pred[:, 5] - target_params[:, 5]).item()
errors = [tx_error, ty_error, rx_error, ry_error, rz_error, scale_error]
targets = target_params[0].cpu().numpy()
preds = decoded_pred[0].cpu().numpy()
for j in range(6):
all_errors[j].append(torch.abs(pred_params[0, j] - target_params[0, j]).item())
all_targets[j].append(target_params[0, j].item())
all_preds[j].append(pred_params[0, j].item())
all_errors[j].append(errors[j])
all_targets[j].append(targets[j])
all_preds[j].append(preds[j])
mean_errors = [np.mean(all_errors[i]) for i in range(6)]
std_errors = [np.std(all_errors[i]) for i in range(6)]
@@ -144,10 +162,18 @@ def analyze_training(trainer):
yandex_img = batch["yandex_img"].to(trainer.device)
target_params = batch["homography_params"].to(trainer.device)
pred_params = trainer.model(google_img, yandex_img)
decoded_pred = trainer.model.decode_output(pred_params)
errors = torch.abs(pred_params[0] - target_params[0]).cpu().numpy()
tx_error = torch.abs(decoded_pred[:, 0] - target_params[:, 0]).cpu().numpy()
ty_error = torch.abs(decoded_pred[:, 1] - target_params[:, 1]).cpu().numpy()
rx_error = angular_difference(decoded_pred[:, 2], target_params[:, 2]).cpu().numpy()
ry_error = angular_difference(decoded_pred[:, 3], target_params[:, 3]).cpu().numpy()
rz_error = angular_difference(decoded_pred[:, 4], target_params[:, 4]).cpu().numpy()
scale_error = torch.abs(decoded_pred[:, 5] - target_params[:, 5]).cpu().numpy()
errors = np.array([tx_error[0], ty_error[0], rx_error[0], ry_error[0], rz_error[0], scale_error[0]])
targets = target_params[0].cpu().numpy()
preds = pred_params[0].cpu().numpy()
preds = decoded_pred[0].cpu().numpy()
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

View File

@@ -18,7 +18,7 @@ class HomographyCNN6(nn.Module):
nn.Linear(512, 256),
nn.ReLU(inplace=True),
nn.Dropout(dropout_rate),
nn.Linear(256, 6),
nn.Linear(256, 9),
)
def forward(self, img1, img2):
@@ -27,14 +27,95 @@ class HomographyCNN6(nn.Module):
combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)
return self.head(combined)
def decode_output(self, output):
    """Convert the raw 9-column head output into 6 pose parameters.

    Column layout read from ``output`` (a 2-D tensor with at least 9 columns):
    0-1 are tx/ty, 2-7 are three raw (sin, cos) pairs, 8 is scale. Each raw
    sin/cos value is squashed with ``tanh`` and every pair is turned back into
    an angle with ``atan2``.

    Args:
        output: Tensor indexed as ``output[:, k]`` for ``k`` in 0..8.

    Returns:
        Tensor with 6 columns: ``[tx, ty, angle1, angle2, angle3, scale]``.
    """
    translation = output[:, 0:2]
    # Explicit index (not a slice) so a too-narrow input still raises.
    scale = output[:, 8].unsqueeze(1)
    # Squash all six sin/cos columns at once; even offsets are the sines,
    # odd offsets the cosines.
    squashed = torch.tanh(output[:, 2:8])
    angles = torch.atan2(squashed[:, 0::2], squashed[:, 1::2])
    return torch.cat([translation, angles, scale], dim=1)
class HomographyLoss6(nn.Module):
    """Weighted loss for a 9-column (tx, ty, 3x(sin, cos), scale) prediction.

    ``pred`` is indexed as ``pred[:, 0..8]``: columns 0-1 are translation,
    columns 2-7 are three raw (sin, cos) pairs, column 8 is scale.
    ``target`` is indexed as ``target[:, 0..5]``: columns 0-1 are translation,
    columns 2-4 are angles in radians (they are fed to ``sin``/``cos``),
    column 5 is scale.

    Translation and scale use MSE. Each angle uses ``1 - cos(error)``, where
    the predicted direction is the ``tanh``-squashed (sin, cos) pair
    normalised to unit length. Without that normalisation the term is linear
    in a prediction confined to the box ``[-1, 1]^2`` and is minimised at a
    saturated corner (e.g. (1, 1), i.e. 45 degrees) instead of at the target
    angle, and it can become negative.

    Args:
        angle_loss_weight: Multiplier for the summed angle terms.
        trans_loss_weight: Multiplier for the summed tx/ty MSE terms.
        scale_loss_weight: Multiplier for the scale MSE term.
    """

    def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0):
        super().__init__()
        self.criterion = nn.MSELoss()
        self.angle_loss_weight = angle_loss_weight
        self.trans_loss_weight = trans_loss_weight
        self.scale_loss_weight = scale_loss_weight

    @staticmethod
    def _angle_loss(sin_raw, cos_raw, target_angle, eps=1e-8):
        """Return mean ``1 - cos(pred_angle - target_angle)`` for one angle."""
        sin_pred = torch.tanh(sin_raw)
        cos_pred = torch.tanh(cos_raw)
        # eps sits inside the sqrt so the gradient stays finite at (0, 0).
        norm = torch.sqrt(sin_pred ** 2 + cos_pred ** 2 + eps)
        alignment = (
            sin_pred * torch.sin(target_angle) + cos_pred * torch.cos(target_angle)
        ) / norm
        return (1 - alignment).mean()

    def _component_losses(self, pred, target):
        """Return ``(tx, ty, [angle1, angle2, angle3], scale)`` loss tensors."""
        tx_loss = self.criterion(pred[:, 0], target[:, 0])
        ty_loss = self.criterion(pred[:, 1], target[:, 1])
        # Pair k lives in pred columns (2 + 2k, 3 + 2k); its angle in target column 2 + k.
        angle_losses = [
            self._angle_loss(pred[:, 2 + 2 * k], pred[:, 3 + 2 * k], target[:, 2 + k])
            for k in range(3)
        ]
        scale_loss = self.criterion(pred[:, 8], target[:, 5])
        return tx_loss, ty_loss, angle_losses, scale_loss

    def forward(self, pred, target):
        """Return the scalar weighted sum of translation, angle and scale terms."""
        tx_loss, ty_loss, angle_losses, scale_loss = self._component_losses(pred, target)
        angle1_loss, angle2_loss, angle3_loss = angle_losses
        total_loss = (
            self.trans_loss_weight * (tx_loss + ty_loss) +
            self.angle_loss_weight * (angle1_loss + angle2_loss + angle3_loss) +
            self.scale_loss_weight * scale_loss
        )
        return total_loss

    def compute_mse_components(self, pred, target):
        """Return unweighted per-group metrics as plain Python floats.

        Returns:
            Dict with keys ``'trans'`` (mean of tx and ty MSE), ``'angle'``
            (mean of the three ``1 - cos(error)`` terms; not an MSE despite
            the method name) and ``'scale'`` (scale MSE).
        """
        # Metrics only: no autograd graph is needed for values read via .item().
        with torch.no_grad():
            tx_loss, ty_loss, angle_losses, scale_loss = self._component_losses(pred, target)
        tx_mse = tx_loss.item()
        ty_mse = ty_loss.item()
        avg_angle_loss = sum(term.item() for term in angle_losses) / 3
        return {
            'trans': (tx_mse + ty_mse) / 2,
            'angle': avg_angle_loss,
            'scale': scale_loss.item()
        }
def count_parameters(model):

View File

@@ -49,9 +49,10 @@ class HomographyTrainer:
total_loss += loss.item() * google_img.size(0)
total_samples += google_img.size(0)
mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)
mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)
mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)
mse_components = self.criterion.compute_mse_components(output, target)
mse_trans_sum += mse_components['trans'] * google_img.size(0)
mse_angle_sum += mse_components['angle'] * google_img.size(0)
mse_scale_sum += mse_components['scale'] * google_img.size(0)
pbar.set_postfix({"loss": loss.item()})
@@ -75,9 +76,10 @@ class HomographyTrainer:
total_loss += loss.item() * google_img.size(0)
total_samples += google_img.size(0)
mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)
mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)
mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)
mse_components = self.criterion.compute_mse_components(output, target)
mse_trans_sum += mse_components['trans'] * google_img.size(0)
mse_angle_sum += mse_components['angle'] * google_img.size(0)
mse_scale_sum += mse_components['scale'] * google_img.size(0)
self.val_mse_trans.append(mse_trans_sum / total_samples)
self.val_mse_angle.append(mse_angle_sum / total_samples)