feat(SiaN): change output angles to sin/cos values
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,12 @@ IMG_DIR = os.path.join(config["output_dir"], "images")
|
||||
os.makedirs(IMG_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def angular_difference(pred_angles, target_angles):
    """Return the absolute, wrap-aware difference between two angle tensors.

    The raw difference is folded through ``atan2(sin, cos)`` so that it
    lands in ``[-pi, pi]`` before the absolute value is taken; e.g. angles
    just either side of the +/-pi seam come out as a small error rather
    than one close to ``2 * pi``.

    Args:
        pred_angles: tensor of predicted angles, in radians.
        target_angles: tensor of reference angles, in radians, with a shape
            that supports subtraction from ``pred_angles``.

    Returns:
        Tensor of non-negative angular errors in ``[0, pi]``.
    """
    delta = pred_angles - target_angles
    # atan2 of the (sin, cos) pair recovers the equivalent angle in [-pi, pi].
    wrapped = torch.atan2(delta.sin(), delta.cos())
    return wrapped.abs()
|
||||
|
||||
|
||||
def analyze_training(trainer):
|
||||
print("=== Training Analysis ===\n")
|
||||
|
||||
@@ -26,7 +32,7 @@ def analyze_training(trainer):
|
||||
trainer.model.eval()
|
||||
|
||||
n_samples = 50
|
||||
names = ["rx", "ry", "rz", "tx", "ty", "scale"]
|
||||
names = ["tx", "ty", "rx", "ry", "rz", "scale"]
|
||||
|
||||
with torch.no_grad():
|
||||
all_errors = [[] for _ in range(6)]
|
||||
@@ -42,11 +48,23 @@ def analyze_training(trainer):
|
||||
yandex_img = batch["yandex_img"].to(trainer.device)
|
||||
target_params = batch["homography_params"].to(trainer.device)
|
||||
pred_params = trainer.model(google_img, yandex_img)
|
||||
decoded_pred = trainer.model.decode_output(pred_params)
|
||||
|
||||
tx_error = torch.abs(decoded_pred[:, 0] - target_params[:, 0]).item()
|
||||
ty_error = torch.abs(decoded_pred[:, 1] - target_params[:, 1]).item()
|
||||
rx_error = angular_difference(decoded_pred[:, 2], target_params[:, 2]).item()
|
||||
ry_error = angular_difference(decoded_pred[:, 3], target_params[:, 3]).item()
|
||||
rz_error = angular_difference(decoded_pred[:, 4], target_params[:, 4]).item()
|
||||
scale_error = torch.abs(decoded_pred[:, 5] - target_params[:, 5]).item()
|
||||
|
||||
errors = [tx_error, ty_error, rx_error, ry_error, rz_error, scale_error]
|
||||
targets = target_params[0].cpu().numpy()
|
||||
preds = decoded_pred[0].cpu().numpy()
|
||||
|
||||
for j in range(6):
|
||||
all_errors[j].append(torch.abs(pred_params[0, j] - target_params[0, j]).item())
|
||||
all_targets[j].append(target_params[0, j].item())
|
||||
all_preds[j].append(pred_params[0, j].item())
|
||||
all_errors[j].append(errors[j])
|
||||
all_targets[j].append(targets[j])
|
||||
all_preds[j].append(preds[j])
|
||||
|
||||
mean_errors = [np.mean(all_errors[i]) for i in range(6)]
|
||||
std_errors = [np.std(all_errors[i]) for i in range(6)]
|
||||
@@ -144,10 +162,18 @@ def analyze_training(trainer):
|
||||
yandex_img = batch["yandex_img"].to(trainer.device)
|
||||
target_params = batch["homography_params"].to(trainer.device)
|
||||
pred_params = trainer.model(google_img, yandex_img)
|
||||
decoded_pred = trainer.model.decode_output(pred_params)
|
||||
|
||||
errors = torch.abs(pred_params[0] - target_params[0]).cpu().numpy()
|
||||
tx_error = torch.abs(decoded_pred[:, 0] - target_params[:, 0]).cpu().numpy()
|
||||
ty_error = torch.abs(decoded_pred[:, 1] - target_params[:, 1]).cpu().numpy()
|
||||
rx_error = angular_difference(decoded_pred[:, 2], target_params[:, 2]).cpu().numpy()
|
||||
ry_error = angular_difference(decoded_pred[:, 3], target_params[:, 3]).cpu().numpy()
|
||||
rz_error = angular_difference(decoded_pred[:, 4], target_params[:, 4]).cpu().numpy()
|
||||
scale_error = torch.abs(decoded_pred[:, 5] - target_params[:, 5]).cpu().numpy()
|
||||
|
||||
errors = np.array([tx_error[0], ty_error[0], rx_error[0], ry_error[0], rz_error[0], scale_error[0]])
|
||||
targets = target_params[0].cpu().numpy()
|
||||
preds = pred_params[0].cpu().numpy()
|
||||
preds = decoded_pred[0].cpu().numpy()
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class HomographyCNN6(nn.Module):
|
||||
nn.Linear(512, 256),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(dropout_rate),
|
||||
nn.Linear(256, 6),
|
||||
nn.Linear(256, 9),
|
||||
)
|
||||
|
||||
def forward(self, img1, img2):
|
||||
@@ -27,14 +27,95 @@ class HomographyCNN6(nn.Module):
|
||||
combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)
|
||||
return self.head(combined)
|
||||
|
||||
def decode_output(self, output):
    """Convert the raw 9-column head output into 6 pose parameters.

    Column layout read from ``output``:
        0, 1 -> tx, ty (passed through unchanged)
        2..7 -> three (sin, cos) pairs, each squashed with ``tanh``
        8    -> scale (passed through unchanged)

    Each squashed (sin, cos) pair is turned back into an angle with
    ``atan2``, so only the direction of the pair matters.

    Args:
        output: tensor indexed as ``output[:, k]`` for ``k`` in 0..8.

    Returns:
        Tensor with columns ``[tx, ty, angle1, angle2, angle3, scale]``,
        stacked along ``dim=1``.
    """
    # One angle per (sin, cos) column pair: (2, 3), (4, 5), (6, 7).
    angles = [
        torch.atan2(torch.tanh(output[:, sin_col]), torch.tanh(output[:, sin_col + 1]))
        for sin_col in (2, 4, 6)
    ]
    columns = [output[:, 0], output[:, 1], *angles, output[:, 8]]
    return torch.stack(columns, dim=1)
|
||||
|
||||
|
||||
class HomographyLoss6(nn.Module):
    """Weighted loss for the 9-column homography head.

    Column layout, as indexed by this class:
        pred:   [tx, ty, sin1, cos1, sin2, cos2, sin3, cos3, scale] (raw head output)
        target: [tx, ty, angle1, angle2, angle3, scale] (angles in radians)

    Translation and scale use MSE. Each angle uses ``1 - cos(error)``,
    computed as one minus the cosine similarity between the predicted
    (sin, cos) pair and the target's unit (sin, cos) vector.

    Args:
        angle_loss_weight: multiplier for the summed angle terms.
        trans_loss_weight: multiplier for the summed tx/ty MSE terms.
        scale_loss_weight: multiplier for the scale MSE term.
    """

    def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0):
        super().__init__()
        self.criterion = nn.MSELoss()
        self.angle_loss_weight = angle_loss_weight
        self.trans_loss_weight = trans_loss_weight
        self.scale_loss_weight = scale_loss_weight

    @staticmethod
    def _angle_losses(pred, target, eps=1e-8):
        """Return the three per-angle ``1 - cos(error)`` losses as scalar tensors."""
        losses = []
        for k in range(3):
            sin_pred = torch.tanh(pred[:, 2 + 2 * k])
            cos_pred = torch.tanh(pred[:, 3 + 2 * k])
            # Normalise the predicted pair to unit length. Without this the
            # term is linear in the tanh-bounded outputs, so its minimum sits
            # at the corners of the [-1, 1]^2 box (decoded angles of +/-45 or
            # +/-135 degrees) rather than at the target direction, and the
            # loss can go negative. eps inside the sqrt keeps the gradient
            # finite when both components are zero.
            norm = torch.sqrt(sin_pred ** 2 + cos_pred ** 2 + eps)
            angle = target[:, 2 + k]
            cos_sim = (sin_pred * torch.sin(angle) + cos_pred * torch.cos(angle)) / norm
            losses.append((1 - cos_sim).mean())
        return losses

    def forward(self, pred, target):
        """Return the weighted total loss as a scalar tensor.

        Args:
            pred: raw head output, indexed as ``pred[:, 0..8]``.
            target: ground-truth parameters, indexed as ``target[:, 0..5]``.
        """
        tx_loss = self.criterion(pred[:, 0], target[:, 0])
        ty_loss = self.criterion(pred[:, 1], target[:, 1])
        angle1_loss, angle2_loss, angle3_loss = self._angle_losses(pred, target)
        scale_loss = self.criterion(pred[:, 8], target[:, 5])

        total_loss = (
            self.trans_loss_weight * (tx_loss + ty_loss)
            + self.angle_loss_weight * (angle1_loss + angle2_loss + angle3_loss)
            + self.scale_loss_weight * scale_loss
        )
        return total_loss

    def compute_mse_components(self, pred, target):
        """Return unweighted per-group metrics as plain Python floats.

        Returns:
            Dict with keys ``'trans'`` (mean of tx and ty MSE), ``'angle'``
            (mean of the three angle losses) and ``'scale'`` (scale MSE).
        """
        # Metrics only: no autograd graph is needed for values read via .item().
        with torch.no_grad():
            tx_mse = self.criterion(pred[:, 0], target[:, 0]).item()
            ty_mse = self.criterion(pred[:, 1], target[:, 1]).item()
            angle_losses = [loss.item() for loss in self._angle_losses(pred, target)]
            scale_mse = self.criterion(pred[:, 8], target[:, 5]).item()

        avg_angle_loss = sum(angle_losses) / 3

        return {
            'trans': (tx_mse + ty_mse) / 2,
            'angle': avg_angle_loss,
            'scale': scale_mse,
        }
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
|
||||
@@ -49,9 +49,10 @@ class HomographyTrainer:
|
||||
total_loss += loss.item() * google_img.size(0)
|
||||
total_samples += google_img.size(0)
|
||||
|
||||
mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)
|
||||
mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)
|
||||
mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)
|
||||
mse_components = self.criterion.compute_mse_components(output, target)
|
||||
mse_trans_sum += mse_components['trans'] * google_img.size(0)
|
||||
mse_angle_sum += mse_components['angle'] * google_img.size(0)
|
||||
mse_scale_sum += mse_components['scale'] * google_img.size(0)
|
||||
|
||||
pbar.set_postfix({"loss": loss.item()})
|
||||
|
||||
@@ -75,9 +76,10 @@ class HomographyTrainer:
|
||||
total_loss += loss.item() * google_img.size(0)
|
||||
total_samples += google_img.size(0)
|
||||
|
||||
mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)
|
||||
mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)
|
||||
mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)
|
||||
mse_components = self.criterion.compute_mse_components(output, target)
|
||||
mse_trans_sum += mse_components['trans'] * google_img.size(0)
|
||||
mse_angle_sum += mse_components['angle'] * google_img.size(0)
|
||||
mse_scale_sum += mse_components['scale'] * google_img.size(0)
|
||||
|
||||
self.val_mse_trans.append(mse_trans_sum / total_samples)
|
||||
self.val_mse_angle.append(mse_angle_sum / total_samples)
|
||||
|
||||
Reference in New Issue
Block a user