feat(SiaN): improve analyzing
This commit is contained in:
2
models/SiaN/.gitignore
vendored
2
models/SiaN/.gitignore
vendored
@@ -1,3 +1,3 @@
|
||||
runs
|
||||
*.gen.py
|
||||
*.img
|
||||
*.png
|
||||
@@ -18,7 +18,6 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from torchvision import transforms, models
|
||||
from tqdm import tqdm
|
||||
|
||||
# code: ./src/utils.py
|
||||
# markdown
|
||||
"""# Configuration
|
||||
|
||||
@@ -28,8 +27,8 @@ Global settings for:
|
||||
- Model architecture options
|
||||
|
||||
Contains the `config` dictionary used across all modules."""
|
||||
# code: ./src/utils.py
|
||||
|
||||
# code: ./src/dataloader.py
|
||||
# markdown
|
||||
"""## Dataset
|
||||
|
||||
@@ -42,8 +41,8 @@ Google/Yandex image pair loader with homography augmentation.
|
||||
|
||||
**Returns:**
|
||||
- Batch dict with `google_img`, `yandex_img`, `homography_params`"""
|
||||
# code: ./src/dataloader.py
|
||||
|
||||
# code: ./src/model.py
|
||||
# markdown
|
||||
"""## Model
|
||||
|
||||
@@ -58,8 +57,8 @@ Google/Yandex image pair loader with homography augmentation.
|
||||
- Dual-branch CNN (Google + Yandex images)
|
||||
- Shared backbone (configurable: resnet18/34/50)
|
||||
- Fusion head with dropout regularization"""
|
||||
# code: ./src/model.py
|
||||
|
||||
# code: ./src/train.py
|
||||
# markdown
|
||||
"""## Training
|
||||
|
||||
@@ -75,8 +74,8 @@ Google/Yandex image pair loader with homography augmentation.
|
||||
**Checkpoint saving:**
|
||||
- `best_model.pt` — lowest validation loss
|
||||
- `checkpoint_epoch_N.pt` — periodic saves"""
|
||||
# code: ./src/train.py
|
||||
|
||||
# code: ./src/analyze.py
|
||||
# markdown
|
||||
"""## Analysis
|
||||
|
||||
@@ -85,8 +84,8 @@ Visualization and evaluation tools:
|
||||
- Training metrics plots (loss curves)
|
||||
- Prediction visualization on sample images
|
||||
- Error analysis and statistics"""
|
||||
# code: ./src/analyze.py
|
||||
|
||||
# code: ./src/main.py
|
||||
# markdown
|
||||
"""## Main Pipeline
|
||||
|
||||
@@ -101,6 +100,7 @@ Executes the full training workflow:
|
||||
- Model checkpoints in `runs/checkpoints/`
|
||||
- TensorBoard logs in `runs/`
|
||||
- Analysis plots"""
|
||||
# code: ./src/main.py
|
||||
|
||||
# # shell:
|
||||
# !zip artefacts.zip runs/gan_training/checkpoints/best_model.pt
|
||||
# !zip artefacts.zip runs/checkpoints/best_model.pt runs/images/ runs/events.*
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,13 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from .utils import config
|
||||
|
||||
IMG_DIR = os.path.join(config["output_dir"], "images")
|
||||
os.makedirs(IMG_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def analyze_training(trainer):
|
||||
print("=== Training Analysis ===\n")
|
||||
@@ -11,42 +17,189 @@ def analyze_training(trainer):
|
||||
|
||||
print(f"\nBest val loss: {trainer.best_val_loss:.4f}")
|
||||
|
||||
best_model_path = os.path.join(config["output_dir"], "checkpoints", "best_model.pt")
|
||||
if os.path.exists(best_model_path):
|
||||
checkpoint = torch.load(best_model_path, map_location=trainer.device)
|
||||
trainer.model.load_state_dict(checkpoint["model_state_dict"])
|
||||
print(f"\nLoaded best model from epoch {checkpoint['epoch']} (val loss: {checkpoint['val_loss']:.4f})")
|
||||
|
||||
trainer.model.eval()
|
||||
|
||||
n_samples = 50
|
||||
names = ["rx", "ry", "rz", "tx", "ty", "scale"]
|
||||
|
||||
with torch.no_grad():
|
||||
batch = next(iter(trainer.val_loader))
|
||||
google_img = batch["google_img"].to(trainer.device)
|
||||
yandex_img = batch["yandex_img"].to(trainer.device)
|
||||
target_params = batch["homography_params"].to(trainer.device)
|
||||
|
||||
pred_params = trainer.model(google_img, yandex_img)
|
||||
|
||||
print(f"\nSample predictions (first 3 of batch):")
|
||||
print(f"{'Param':<8} {'Target':>12} {'Predicted':>12} {'Error':>12}")
|
||||
print("-" * 46)
|
||||
names = ["rx", "ry", "rz", "tx", "ty", "scale"]
|
||||
for i in range(6):
|
||||
t = target_params[0, i].item()
|
||||
p = pred_params[0, i].item()
|
||||
print(f"{names[i]:<8} {t:>12.4f} {p:>12.4f} {abs(t-p):>12.4f}")
|
||||
|
||||
print(f"\nBatch mean abs error: {torch.mean(torch.abs(pred_params - target_params)).item():.4f}")
|
||||
|
||||
print("\n=== Visualization ===")
|
||||
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
||||
img1 = google_img[0].cpu()
|
||||
img2 = yandex_img[0].cpu()
|
||||
axes[0].imshow(img1.permute(1, 2, 0))
|
||||
axes[0].set_title("Google")
|
||||
axes[0].axis("off")
|
||||
axes[1].imshow(img2.permute(1, 2, 0))
|
||||
axes[1].set_title("Yandex")
|
||||
axes[1].axis("off")
|
||||
axes[2].bar(names, pred_params[0].cpu().numpy())
|
||||
axes[2].set_title("Predicted params")
|
||||
axes[2].axhline(y=0, color="k", lw=0.5)
|
||||
all_errors = [[] for _ in range(6)]
|
||||
all_targets = [[] for _ in range(6)]
|
||||
all_preds = [[] for _ in range(6)]
|
||||
|
||||
for i in range(n_samples):
|
||||
try:
|
||||
batch = next(iter(trainer.val_loader))
|
||||
except StopIteration:
|
||||
break
|
||||
google_img = batch["google_img"].to(trainer.device)
|
||||
yandex_img = batch["yandex_img"].to(trainer.device)
|
||||
target_params = batch["homography_params"].to(trainer.device)
|
||||
pred_params = trainer.model(google_img, yandex_img)
|
||||
|
||||
for j in range(6):
|
||||
all_errors[j].append(torch.abs(pred_params[0, j] - target_params[0, j]).item())
|
||||
all_targets[j].append(target_params[0, j].item())
|
||||
all_preds[j].append(pred_params[0, j].item())
|
||||
|
||||
mean_errors = [np.mean(all_errors[i]) for i in range(6)]
|
||||
std_errors = [np.std(all_errors[i]) for i in range(6)]
|
||||
|
||||
if len(trainer.train_losses) > 0:
|
||||
epochs = range(1, len(trainer.train_losses) + 1)
|
||||
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
|
||||
|
||||
axes[0, 0].plot(epochs, trainer.train_losses, "b-", label="Train Loss")
|
||||
axes[0, 0].plot(epochs, trainer.val_losses, "r-", label="Val Loss")
|
||||
axes[0, 0].set_xlabel("Epoch")
|
||||
axes[0, 0].set_ylabel("Loss")
|
||||
axes[0, 0].set_title("Training & Validation Loss")
|
||||
axes[0, 0].legend()
|
||||
axes[0, 0].grid(True, alpha=0.3)
|
||||
|
||||
axes[0, 1].plot(epochs, trainer.val_losses, "r-", label="Val Loss")
|
||||
axes[0, 1].set_xlabel("Epoch")
|
||||
axes[0, 1].set_ylabel("Loss")
|
||||
axes[0, 1].set_title("Validation Loss")
|
||||
axes[0, 1].legend()
|
||||
axes[0, 1].grid(True, alpha=0.3)
|
||||
|
||||
axes[1, 0].plot(epochs, trainer.val_mse_trans, "g-", label="Translation (tx, ty)")
|
||||
axes[1, 0].plot(epochs, trainer.val_mse_angle, "m-", label="Angle (rx, ry, rz)")
|
||||
axes[1, 0].plot(epochs, trainer.val_mse_scale, "c-", label="Scale")
|
||||
axes[1, 0].set_xlabel("Epoch")
|
||||
axes[1, 0].set_ylabel("MSE")
|
||||
axes[1, 0].set_title("Validation MSE by Category")
|
||||
axes[1, 0].legend()
|
||||
axes[1, 0].grid(True, alpha=0.3)
|
||||
|
||||
x_pos = np.arange(6)
|
||||
axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=["c", "m", "y", "g", "b", "r"], alpha=0.8)
|
||||
axes[1, 1].set_xticks(x_pos)
|
||||
axes[1, 1].set_xticklabels(names)
|
||||
axes[1, 1].set_ylabel("Mean Absolute Error")
|
||||
axes[1, 1].set_title(f"Mean Absolute Error per Parameter ({n_samples} samples)")
|
||||
axes[1, 1].grid(True, alpha=0.3, axis="y")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(IMG_DIR, "training_loss_plots.png"), dpi=150)
|
||||
print("Saved training_loss_plots.png")
|
||||
plt.show()
|
||||
|
||||
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
|
||||
for j in range(6):
|
||||
row = j // 3
|
||||
col = j % 3
|
||||
axes[row, col].bar(range(n_samples), all_errors[j], color="steelblue", alpha=0.7)
|
||||
axes[row, col].set_xlabel("Sample")
|
||||
axes[row, col].set_ylabel("Absolute Error")
|
||||
axes[row, col].set_title(f"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}")
|
||||
axes[row, col].grid(True, alpha=0.3, axis="y")
|
||||
plt.suptitle(f"Mean Absolute Error per Parameter ({n_samples} samples)", fontsize=14)
|
||||
plt.tight_layout()
|
||||
plt.savefig("prediction_sample.png")
|
||||
print("Saved prediction_sample.png")
|
||||
plt.savefig(os.path.join(IMG_DIR, "mae_per_parameter.png"), dpi=150)
|
||||
print("Saved mae_per_parameter.png")
|
||||
plt.show()
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
|
||||
|
||||
x_pos = np.arange(6)
|
||||
axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=["c", "m", "y", "g", "b", "r"], alpha=0.8)
|
||||
axes[0].set_xticks(x_pos)
|
||||
axes[0].set_xticklabels(names)
|
||||
axes[0].set_ylabel("Mean Absolute Error")
|
||||
axes[0].set_title("Mean Absolute Error per Parameter (with std)")
|
||||
axes[0].grid(True, alpha=0.3, axis="y")
|
||||
|
||||
bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)
|
||||
colors = ["c", "m", "y", "g", "b", "r"]
|
||||
for patch, color in zip(bp["boxes"], colors):
|
||||
patch.set_facecolor(color)
|
||||
patch.set_alpha(0.7)
|
||||
axes[1].set_ylabel("Absolute Error")
|
||||
axes[1].set_title(f"Error Distribution per Parameter ({n_samples} samples)")
|
||||
axes[1].grid(True, alpha=0.3, axis="y")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(IMG_DIR, "mae_boxplot.png"), dpi=150)
|
||||
print("Saved mae_boxplot.png")
|
||||
plt.show()
|
||||
|
||||
print("\n=== Sample Predictions (20 pairs) ===")
|
||||
n_vis_samples = 20
|
||||
|
||||
with torch.no_grad():
|
||||
for sample_idx in range(n_vis_samples):
|
||||
try:
|
||||
batch = next(iter(trainer.val_loader))
|
||||
except StopIteration:
|
||||
break
|
||||
google_img = batch["google_img"].to(trainer.device)
|
||||
yandex_img = batch["yandex_img"].to(trainer.device)
|
||||
target_params = batch["homography_params"].to(trainer.device)
|
||||
pred_params = trainer.model(google_img, yandex_img)
|
||||
|
||||
errors = torch.abs(pred_params[0] - target_params[0]).cpu().numpy()
|
||||
targets = target_params[0].cpu().numpy()
|
||||
preds = pred_params[0].cpu().numpy()
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||||
|
||||
axes[0, 0].imshow(google_img[0].cpu().permute(1, 2, 0))
|
||||
axes[0, 0].set_title(f"Google Image")
|
||||
axes[0, 0].axis("off")
|
||||
|
||||
axes[0, 1].imshow(yandex_img[0].cpu().permute(1, 2, 0))
|
||||
axes[0, 1].set_title(f"Yandex Image")
|
||||
axes[0, 1].axis("off")
|
||||
|
||||
x_pos = np.arange(6)
|
||||
width = 0.35
|
||||
axes[1, 0].bar(x_pos - width/2, targets, width, label="Target", color="steelblue", alpha=0.8)
|
||||
axes[1, 0].bar(x_pos + width/2, preds, width, label="Predicted", color="coral", alpha=0.8)
|
||||
axes[1, 0].set_xticks(x_pos)
|
||||
axes[1, 0].set_xticklabels(names)
|
||||
axes[1, 0].set_ylabel("Parameter Value")
|
||||
axes[1, 0].set_title("Target vs Predicted")
|
||||
axes[1, 0].legend()
|
||||
axes[1, 0].grid(True, alpha=0.3, axis="y")
|
||||
|
||||
axes[1, 1].bar(x_pos, errors, color=["c", "m", "y", "g", "b", "r"], alpha=0.8)
|
||||
axes[1, 1].set_xticks(x_pos)
|
||||
axes[1, 1].set_xticklabels(names)
|
||||
axes[1, 1].set_ylabel("Absolute Error")
|
||||
axes[1, 1].set_title(f"Prediction Error (Mean: {np.mean(errors):.4f})")
|
||||
axes[1, 1].grid(True, alpha=0.3, axis="y")
|
||||
for i, e in enumerate(errors):
|
||||
axes[1, 1].text(i, e + 0.01, f"{e:.3f}", ha="center", va="bottom", fontsize=8)
|
||||
|
||||
plt.suptitle(f"Sample {sample_idx + 1}", fontsize=14)
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(IMG_DIR, f"prediction_sample_{sample_idx + 1:02d}.png"), dpi=100)
|
||||
plt.show()
|
||||
print(f"Saved prediction_sample_{sample_idx + 1:02d}.png")
|
||||
|
||||
print(f"\nPrediction errors over {n_samples} samples:")
|
||||
print(f"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8}")
|
||||
print("-" * 52)
|
||||
for i in range(6):
|
||||
mean_err = np.mean(all_errors[i])
|
||||
std_err = np.std(all_errors[i])
|
||||
min_err = np.min(all_errors[i])
|
||||
max_err = np.max(all_errors[i])
|
||||
print(f"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f}")
|
||||
|
||||
return {"best_val_loss": trainer.best_val_loss}
|
||||
return {
|
||||
"best_val_loss": trainer.best_val_loss,
|
||||
"train_losses": trainer.train_losses,
|
||||
"val_losses": trainer.val_losses,
|
||||
"val_mse_trans": trainer.val_mse_trans,
|
||||
"val_mse_angle": trainer.val_mse_angle,
|
||||
"val_mse_scale": trainer.val_mse_scale,
|
||||
}
|
||||
|
||||
@@ -14,13 +14,16 @@ from .utils import config, get_camera_matrix, generate_random_homography_params,
|
||||
|
||||
class YaGoDataset(Dataset):
|
||||
def __init__(self, root_dir: str, transform=None, augment: bool = True,
|
||||
image_size: Tuple[int, int] = (256, 256)):
|
||||
image_size: Tuple[int, int] = (256, 256), cache_level: int = 5):
|
||||
self.root_dir = root_dir
|
||||
self.transform = transform
|
||||
self.augment = augment
|
||||
self.image_size = image_size
|
||||
self.cache_level = cache_level
|
||||
self.K = get_camera_matrix(image_size[1], image_size[0])
|
||||
self.image_pairs = self._discover_image_pairs()
|
||||
self._load_images_to_memory()
|
||||
self._init_cache()
|
||||
|
||||
def _discover_image_pairs(self):
|
||||
pairs = []
|
||||
@@ -32,28 +35,72 @@ class YaGoDataset(Dataset):
|
||||
pairs.append({"idx": int(idx), "google": os.path.join(self.root_dir, f), "yandex": yandex_path})
|
||||
return sorted(pairs, key=lambda x: x["idx"])
|
||||
|
||||
def _load_images_to_memory(self):
|
||||
self._google_images = []
|
||||
self._yandex_images = []
|
||||
for pair in self.image_pairs:
|
||||
google_img = cv2.imread(pair["google"])
|
||||
google_img = cv2.cvtColor(google_img, cv2.COLOR_BGR2RGB)
|
||||
google_img = cv2.resize(google_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
yandex_img = cv2.imread(pair["yandex"])
|
||||
yandex_img = cv2.cvtColor(yandex_img, cv2.COLOR_BGR2RGB)
|
||||
yandex_img = cv2.resize(yandex_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
self._google_images.append(google_img)
|
||||
self._yandex_images.append(yandex_img)
|
||||
|
||||
def _init_cache(self):
|
||||
self._access_counts = [0] * len(self.image_pairs)
|
||||
self._cached_google = [None] * len(self.image_pairs)
|
||||
self._cached_yandex = [None] * len(self.image_pairs)
|
||||
self._cached_homography = [None] * len(self.image_pairs)
|
||||
|
||||
def _generate_augmented(self, idx):
|
||||
google_img = self._google_images[idx].copy()
|
||||
yandex_img = self._yandex_images[idx].copy()
|
||||
|
||||
params1 = generate_random_homography_params()
|
||||
params2 = generate_random_homography_params()
|
||||
H1 = homography_params_to_matrix(params1, self.K)
|
||||
H2 = homography_params_to_matrix(params2, self.K)
|
||||
H_combined = np.linalg.inv(H1) @ H2
|
||||
|
||||
google_warped = cv2.warpPerspective(google_img, H2, (self.image_size[1], self.image_size[0]))
|
||||
yandex_warped = cv2.warpPerspective(yandex_img, H1, (self.image_size[1], self.image_size[0]))
|
||||
|
||||
target_params = matrix_to_homography_params(H_combined, self.K)
|
||||
|
||||
return google_warped, yandex_warped, H_combined, target_params
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_pairs)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
pair = self.image_pairs[idx]
|
||||
google_img = Image.open(pair["google"]).convert("RGB").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
|
||||
yandex_img = Image.open(pair["yandex"]).convert("RGB").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
|
||||
|
||||
if self.augment:
|
||||
params1 = generate_random_homography_params()
|
||||
params2 = generate_random_homography_params()
|
||||
H1 = homography_params_to_matrix(params1, self.K)
|
||||
H2 = homography_params_to_matrix(params2, self.K)
|
||||
H_combined = np.linalg.inv(H1) @ H2
|
||||
yandex_img = Image.fromarray(cv2.warpPerspective(np.array(yandex_img), H1, self.image_size))
|
||||
google_img = Image.fromarray(cv2.warpPerspective(np.array(google_img), H2, self.image_size))
|
||||
target_params = matrix_to_homography_params(H_combined, self.K)
|
||||
target_matrix = H_combined
|
||||
self._access_counts[idx] += 1
|
||||
|
||||
use_cache = self.augment and self.cache_level > 0 and self._access_counts[idx] > 1 and (self._access_counts[idx] - 1) % self.cache_level != 0
|
||||
|
||||
if use_cache:
|
||||
google_img = self._cached_google[idx]
|
||||
yandex_img = self._cached_yandex[idx]
|
||||
target_matrix = self._cached_homography[idx]
|
||||
target_params = matrix_to_homography_params(target_matrix, self.K)
|
||||
elif self.augment:
|
||||
google_img, yandex_img, target_matrix, target_params = self._generate_augmented(idx)
|
||||
if self.cache_level > 0:
|
||||
self._cached_google[idx] = google_img
|
||||
self._cached_yandex[idx] = yandex_img
|
||||
self._cached_homography[idx] = target_matrix
|
||||
else:
|
||||
google_img = self._google_images[idx]
|
||||
yandex_img = self._yandex_images[idx]
|
||||
target_params = np.zeros(6, dtype=np.float32)
|
||||
target_matrix = np.eye(3, dtype=np.float32)
|
||||
|
||||
google_img = Image.fromarray(google_img)
|
||||
yandex_img = Image.fromarray(yandex_img)
|
||||
|
||||
if self.transform:
|
||||
google_img = self.transform(google_img)
|
||||
yandex_img = self.transform(yandex_img)
|
||||
@@ -67,11 +114,14 @@ class YaGoDataset(Dataset):
|
||||
|
||||
|
||||
def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0,
|
||||
image_size=(256, 256), augment_train=True):
|
||||
transform = transforms.Compose([transforms.ToTensor()])
|
||||
image_size=(256, 256), augment_train=True, cache_level=5):
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size)
|
||||
aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size)
|
||||
full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size, cache_level=cache_level)
|
||||
aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size, cache_level=cache_level)
|
||||
|
||||
indices = list(range(len(full_ds)))
|
||||
random.shuffle(indices)
|
||||
|
||||
@@ -21,10 +21,19 @@ class HomographyTrainer:
|
||||
self.optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])
|
||||
self.writer = None
|
||||
self.best_val_loss = float("inf")
|
||||
self.train_losses = []
|
||||
self.val_losses = []
|
||||
self.train_mse_trans = []
|
||||
self.train_mse_angle = []
|
||||
self.train_mse_scale = []
|
||||
self.val_mse_trans = []
|
||||
self.val_mse_angle = []
|
||||
self.val_mse_scale = []
|
||||
|
||||
def train_epoch(self, epoch):
|
||||
self.model.train()
|
||||
total_loss, total_samples = 0, 0
|
||||
mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0
|
||||
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
|
||||
for batch_idx, batch in enumerate(pbar):
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
@@ -39,13 +48,23 @@ class HomographyTrainer:
|
||||
|
||||
total_loss += loss.item() * google_img.size(0)
|
||||
total_samples += google_img.size(0)
|
||||
|
||||
mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)
|
||||
mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)
|
||||
mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)
|
||||
|
||||
pbar.set_postfix({"loss": loss.item()})
|
||||
|
||||
self.train_mse_trans.append(mse_trans_sum / total_samples)
|
||||
self.train_mse_angle.append(mse_angle_sum / total_samples)
|
||||
self.train_mse_scale.append(mse_scale_sum / total_samples)
|
||||
|
||||
return {"loss": total_loss / total_samples}
|
||||
|
||||
def validate(self):
|
||||
self.model.eval()
|
||||
total_loss, total_samples = 0, 0
|
||||
mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(self.val_loader, desc="Validation"):
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
@@ -55,6 +74,15 @@ class HomographyTrainer:
|
||||
loss = self.criterion(output, target)
|
||||
total_loss += loss.item() * google_img.size(0)
|
||||
total_samples += google_img.size(0)
|
||||
|
||||
mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)
|
||||
mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)
|
||||
mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)
|
||||
|
||||
self.val_mse_trans.append(mse_trans_sum / total_samples)
|
||||
self.val_mse_angle.append(mse_angle_sum / total_samples)
|
||||
self.val_mse_scale.append(mse_scale_sum / total_samples)
|
||||
|
||||
return {"loss": total_loss / total_samples}
|
||||
|
||||
def train(self, num_epochs):
|
||||
@@ -65,7 +93,10 @@ class HomographyTrainer:
|
||||
for epoch in range(1, num_epochs + 1):
|
||||
train_metrics = self.train_epoch(epoch)
|
||||
val_metrics = self.validate()
|
||||
self.train_losses.append(train_metrics["loss"])
|
||||
self.val_losses.append(val_metrics["loss"])
|
||||
print(f"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}")
|
||||
print(f" MSE - Trans: {self.val_mse_trans[-1]:.4f}, Angle: {self.val_mse_angle[-1]:.4f}, Scale: {self.val_mse_scale[-1]:.4f}")
|
||||
|
||||
if val_metrics["loss"] < self.best_val_loss:
|
||||
self.best_val_loss = val_metrics["loss"]
|
||||
|
||||
Reference in New Issue
Block a user