feat(SiaN): improve analyzing

This commit is contained in:
2026-04-05 12:50:28 +03:00
parent ec8b3ae20e
commit fa4c4b83ae
6 changed files with 1148 additions and 543 deletions

View File

@@ -1,3 +1,3 @@
runs
*.gen.py
*.img
*.png

View File

@@ -18,7 +18,6 @@ from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, models
from tqdm import tqdm
# code: ./src/utils.py
# markdown
"""# Configuration
@@ -28,8 +27,8 @@ Global settings for:
- Model architecture options
Contains the `config` dictionary used across all modules."""
# code: ./src/utils.py
# code: ./src/dataloader.py
# markdown
"""## Dataset
@@ -42,8 +41,8 @@ Google/Yandex image pair loader with homography augmentation.
**Returns:**
- Batch dict with `google_img`, `yandex_img`, `homography_params`"""
# code: ./src/dataloader.py
# code: ./src/model.py
# markdown
"""## Model
@@ -58,8 +57,8 @@ Google/Yandex image pair loader with homography augmentation.
- Dual-branch CNN (Google + Yandex images)
- Shared backbone (configurable: resnet18/34/50)
- Fusion head with dropout regularization"""
# code: ./src/model.py
# code: ./src/train.py
# markdown
"""## Training
@@ -75,8 +74,8 @@ Google/Yandex image pair loader with homography augmentation.
**Checkpoint saving:**
- `best_model.pt` — lowest validation loss
- `checkpoint_epoch_N.pt` — periodic saves"""
# code: ./src/train.py
# code: ./src/analyze.py
# markdown
"""## Analysis
@@ -85,8 +84,8 @@ Visualization and evaluation tools:
- Training metrics plots (loss curves)
- Prediction visualization on sample images
- Error analysis and statistics"""
# code: ./src/analyze.py
# code: ./src/main.py
# markdown
"""## Main Pipeline
@@ -101,6 +100,7 @@ Executes the full training workflow:
- Model checkpoints in `runs/checkpoints/`
- TensorBoard logs in `runs/`
- Analysis plots"""
# code: ./src/main.py
# # shell:
# !zip artefacts.zip runs/gan_training/checkpoints/best_model.pt
# !zip artefacts.zip runs/checkpoints/best_model.pt runs/images/ runs/events.*

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,13 @@
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from .utils import config
IMG_DIR = os.path.join(config["output_dir"], "images")
os.makedirs(IMG_DIR, exist_ok=True)
def analyze_training(trainer):
print("=== Training Analysis ===\n")
@@ -11,42 +17,189 @@ def analyze_training(trainer):
print(f"\nBest val loss: {trainer.best_val_loss:.4f}")
best_model_path = os.path.join(config["output_dir"], "checkpoints", "best_model.pt")
if os.path.exists(best_model_path):
checkpoint = torch.load(best_model_path, map_location=trainer.device)
trainer.model.load_state_dict(checkpoint["model_state_dict"])
print(f"\nLoaded best model from epoch {checkpoint['epoch']} (val loss: {checkpoint['val_loss']:.4f})")
trainer.model.eval()
n_samples = 50
names = ["rx", "ry", "rz", "tx", "ty", "scale"]
with torch.no_grad():
batch = next(iter(trainer.val_loader))
google_img = batch["google_img"].to(trainer.device)
yandex_img = batch["yandex_img"].to(trainer.device)
target_params = batch["homography_params"].to(trainer.device)
pred_params = trainer.model(google_img, yandex_img)
print(f"\nSample predictions (first 3 of batch):")
print(f"{'Param':<8} {'Target':>12} {'Predicted':>12} {'Error':>12}")
print("-" * 46)
names = ["rx", "ry", "rz", "tx", "ty", "scale"]
for i in range(6):
t = target_params[0, i].item()
p = pred_params[0, i].item()
print(f"{names[i]:<8} {t:>12.4f} {p:>12.4f} {abs(t-p):>12.4f}")
print(f"\nBatch mean abs error: {torch.mean(torch.abs(pred_params - target_params)).item():.4f}")
print("\n=== Visualization ===")
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
img1 = google_img[0].cpu()
img2 = yandex_img[0].cpu()
axes[0].imshow(img1.permute(1, 2, 0))
axes[0].set_title("Google")
axes[0].axis("off")
axes[1].imshow(img2.permute(1, 2, 0))
axes[1].set_title("Yandex")
axes[1].axis("off")
axes[2].bar(names, pred_params[0].cpu().numpy())
axes[2].set_title("Predicted params")
axes[2].axhline(y=0, color="k", lw=0.5)
all_errors = [[] for _ in range(6)]
all_targets = [[] for _ in range(6)]
all_preds = [[] for _ in range(6)]
for i in range(n_samples):
try:
batch = next(iter(trainer.val_loader))
except StopIteration:
break
google_img = batch["google_img"].to(trainer.device)
yandex_img = batch["yandex_img"].to(trainer.device)
target_params = batch["homography_params"].to(trainer.device)
pred_params = trainer.model(google_img, yandex_img)
for j in range(6):
all_errors[j].append(torch.abs(pred_params[0, j] - target_params[0, j]).item())
all_targets[j].append(target_params[0, j].item())
all_preds[j].append(pred_params[0, j].item())
mean_errors = [np.mean(all_errors[i]) for i in range(6)]
std_errors = [np.std(all_errors[i]) for i in range(6)]
if len(trainer.train_losses) > 0:
epochs = range(1, len(trainer.train_losses) + 1)
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
axes[0, 0].plot(epochs, trainer.train_losses, "b-", label="Train Loss")
axes[0, 0].plot(epochs, trainer.val_losses, "r-", label="Val Loss")
axes[0, 0].set_xlabel("Epoch")
axes[0, 0].set_ylabel("Loss")
axes[0, 0].set_title("Training & Validation Loss")
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)
axes[0, 1].plot(epochs, trainer.val_losses, "r-", label="Val Loss")
axes[0, 1].set_xlabel("Epoch")
axes[0, 1].set_ylabel("Loss")
axes[0, 1].set_title("Validation Loss")
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)
axes[1, 0].plot(epochs, trainer.val_mse_trans, "g-", label="Translation (tx, ty)")
axes[1, 0].plot(epochs, trainer.val_mse_angle, "m-", label="Angle (rx, ry, rz)")
axes[1, 0].plot(epochs, trainer.val_mse_scale, "c-", label="Scale")
axes[1, 0].set_xlabel("Epoch")
axes[1, 0].set_ylabel("MSE")
axes[1, 0].set_title("Validation MSE by Category")
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)
x_pos = np.arange(6)
axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=["c", "m", "y", "g", "b", "r"], alpha=0.8)
axes[1, 1].set_xticks(x_pos)
axes[1, 1].set_xticklabels(names)
axes[1, 1].set_ylabel("Mean Absolute Error")
axes[1, 1].set_title(f"Mean Absolute Error per Parameter ({n_samples} samples)")
axes[1, 1].grid(True, alpha=0.3, axis="y")
plt.tight_layout()
plt.savefig(os.path.join(IMG_DIR, "training_loss_plots.png"), dpi=150)
print("Saved training_loss_plots.png")
plt.show()
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
for j in range(6):
row = j // 3
col = j % 3
axes[row, col].bar(range(n_samples), all_errors[j], color="steelblue", alpha=0.7)
axes[row, col].set_xlabel("Sample")
axes[row, col].set_ylabel("Absolute Error")
axes[row, col].set_title(f"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}")
axes[row, col].grid(True, alpha=0.3, axis="y")
plt.suptitle(f"Mean Absolute Error per Parameter ({n_samples} samples)", fontsize=14)
plt.tight_layout()
plt.savefig("prediction_sample.png")
print("Saved prediction_sample.png")
plt.savefig(os.path.join(IMG_DIR, "mae_per_parameter.png"), dpi=150)
print("Saved mae_per_parameter.png")
plt.show()
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
x_pos = np.arange(6)
axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=["c", "m", "y", "g", "b", "r"], alpha=0.8)
axes[0].set_xticks(x_pos)
axes[0].set_xticklabels(names)
axes[0].set_ylabel("Mean Absolute Error")
axes[0].set_title("Mean Absolute Error per Parameter (with std)")
axes[0].grid(True, alpha=0.3, axis="y")
bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)
colors = ["c", "m", "y", "g", "b", "r"]
for patch, color in zip(bp["boxes"], colors):
patch.set_facecolor(color)
patch.set_alpha(0.7)
axes[1].set_ylabel("Absolute Error")
axes[1].set_title(f"Error Distribution per Parameter ({n_samples} samples)")
axes[1].grid(True, alpha=0.3, axis="y")
plt.tight_layout()
plt.savefig(os.path.join(IMG_DIR, "mae_boxplot.png"), dpi=150)
print("Saved mae_boxplot.png")
plt.show()
print("\n=== Sample Predictions (20 pairs) ===")
n_vis_samples = 20
with torch.no_grad():
for sample_idx in range(n_vis_samples):
try:
batch = next(iter(trainer.val_loader))
except StopIteration:
break
google_img = batch["google_img"].to(trainer.device)
yandex_img = batch["yandex_img"].to(trainer.device)
target_params = batch["homography_params"].to(trainer.device)
pred_params = trainer.model(google_img, yandex_img)
errors = torch.abs(pred_params[0] - target_params[0]).cpu().numpy()
targets = target_params[0].cpu().numpy()
preds = pred_params[0].cpu().numpy()
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes[0, 0].imshow(google_img[0].cpu().permute(1, 2, 0))
axes[0, 0].set_title(f"Google Image")
axes[0, 0].axis("off")
axes[0, 1].imshow(yandex_img[0].cpu().permute(1, 2, 0))
axes[0, 1].set_title(f"Yandex Image")
axes[0, 1].axis("off")
x_pos = np.arange(6)
width = 0.35
axes[1, 0].bar(x_pos - width/2, targets, width, label="Target", color="steelblue", alpha=0.8)
axes[1, 0].bar(x_pos + width/2, preds, width, label="Predicted", color="coral", alpha=0.8)
axes[1, 0].set_xticks(x_pos)
axes[1, 0].set_xticklabels(names)
axes[1, 0].set_ylabel("Parameter Value")
axes[1, 0].set_title("Target vs Predicted")
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3, axis="y")
axes[1, 1].bar(x_pos, errors, color=["c", "m", "y", "g", "b", "r"], alpha=0.8)
axes[1, 1].set_xticks(x_pos)
axes[1, 1].set_xticklabels(names)
axes[1, 1].set_ylabel("Absolute Error")
axes[1, 1].set_title(f"Prediction Error (Mean: {np.mean(errors):.4f})")
axes[1, 1].grid(True, alpha=0.3, axis="y")
for i, e in enumerate(errors):
axes[1, 1].text(i, e + 0.01, f"{e:.3f}", ha="center", va="bottom", fontsize=8)
plt.suptitle(f"Sample {sample_idx + 1}", fontsize=14)
plt.tight_layout()
plt.savefig(os.path.join(IMG_DIR, f"prediction_sample_{sample_idx + 1:02d}.png"), dpi=100)
plt.show()
print(f"Saved prediction_sample_{sample_idx + 1:02d}.png")
print(f"\nPrediction errors over {n_samples} samples:")
print(f"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8}")
print("-" * 52)
for i in range(6):
mean_err = np.mean(all_errors[i])
std_err = np.std(all_errors[i])
min_err = np.min(all_errors[i])
max_err = np.max(all_errors[i])
print(f"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f}")
return {"best_val_loss": trainer.best_val_loss}
return {
"best_val_loss": trainer.best_val_loss,
"train_losses": trainer.train_losses,
"val_losses": trainer.val_losses,
"val_mse_trans": trainer.val_mse_trans,
"val_mse_angle": trainer.val_mse_angle,
"val_mse_scale": trainer.val_mse_scale,
}

View File

@@ -14,13 +14,16 @@ from .utils import config, get_camera_matrix, generate_random_homography_params,
class YaGoDataset(Dataset):
def __init__(self, root_dir: str, transform=None, augment: bool = True,
image_size: Tuple[int, int] = (256, 256)):
image_size: Tuple[int, int] = (256, 256), cache_level: int = 5):
self.root_dir = root_dir
self.transform = transform
self.augment = augment
self.image_size = image_size
self.cache_level = cache_level
self.K = get_camera_matrix(image_size[1], image_size[0])
self.image_pairs = self._discover_image_pairs()
self._load_images_to_memory()
self._init_cache()
def _discover_image_pairs(self):
pairs = []
@@ -32,28 +35,72 @@ class YaGoDataset(Dataset):
pairs.append({"idx": int(idx), "google": os.path.join(self.root_dir, f), "yandex": yandex_path})
return sorted(pairs, key=lambda x: x["idx"])
def _load_images_to_memory(self):
self._google_images = []
self._yandex_images = []
for pair in self.image_pairs:
google_img = cv2.imread(pair["google"])
google_img = cv2.cvtColor(google_img, cv2.COLOR_BGR2RGB)
google_img = cv2.resize(google_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)
yandex_img = cv2.imread(pair["yandex"])
yandex_img = cv2.cvtColor(yandex_img, cv2.COLOR_BGR2RGB)
yandex_img = cv2.resize(yandex_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)
self._google_images.append(google_img)
self._yandex_images.append(yandex_img)
def _init_cache(self):
self._access_counts = [0] * len(self.image_pairs)
self._cached_google = [None] * len(self.image_pairs)
self._cached_yandex = [None] * len(self.image_pairs)
self._cached_homography = [None] * len(self.image_pairs)
def _generate_augmented(self, idx):
google_img = self._google_images[idx].copy()
yandex_img = self._yandex_images[idx].copy()
params1 = generate_random_homography_params()
params2 = generate_random_homography_params()
H1 = homography_params_to_matrix(params1, self.K)
H2 = homography_params_to_matrix(params2, self.K)
H_combined = np.linalg.inv(H1) @ H2
google_warped = cv2.warpPerspective(google_img, H2, (self.image_size[1], self.image_size[0]))
yandex_warped = cv2.warpPerspective(yandex_img, H1, (self.image_size[1], self.image_size[0]))
target_params = matrix_to_homography_params(H_combined, self.K)
return google_warped, yandex_warped, H_combined, target_params
def __len__(self):
return len(self.image_pairs)
def __getitem__(self, idx):
pair = self.image_pairs[idx]
google_img = Image.open(pair["google"]).convert("RGB").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
yandex_img = Image.open(pair["yandex"]).convert("RGB").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
if self.augment:
params1 = generate_random_homography_params()
params2 = generate_random_homography_params()
H1 = homography_params_to_matrix(params1, self.K)
H2 = homography_params_to_matrix(params2, self.K)
H_combined = np.linalg.inv(H1) @ H2
yandex_img = Image.fromarray(cv2.warpPerspective(np.array(yandex_img), H1, self.image_size))
google_img = Image.fromarray(cv2.warpPerspective(np.array(google_img), H2, self.image_size))
target_params = matrix_to_homography_params(H_combined, self.K)
target_matrix = H_combined
self._access_counts[idx] += 1
use_cache = self.augment and self.cache_level > 0 and self._access_counts[idx] > 1 and (self._access_counts[idx] - 1) % self.cache_level != 0
if use_cache:
google_img = self._cached_google[idx]
yandex_img = self._cached_yandex[idx]
target_matrix = self._cached_homography[idx]
target_params = matrix_to_homography_params(target_matrix, self.K)
elif self.augment:
google_img, yandex_img, target_matrix, target_params = self._generate_augmented(idx)
if self.cache_level > 0:
self._cached_google[idx] = google_img
self._cached_yandex[idx] = yandex_img
self._cached_homography[idx] = target_matrix
else:
google_img = self._google_images[idx]
yandex_img = self._yandex_images[idx]
target_params = np.zeros(6, dtype=np.float32)
target_matrix = np.eye(3, dtype=np.float32)
google_img = Image.fromarray(google_img)
yandex_img = Image.fromarray(yandex_img)
if self.transform:
google_img = self.transform(google_img)
yandex_img = self.transform(yandex_img)
@@ -67,11 +114,14 @@ class YaGoDataset(Dataset):
def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0,
image_size=(256, 256), augment_train=True):
transform = transforms.Compose([transforms.ToTensor()])
image_size=(256, 256), augment_train=True, cache_level=5):
transform = transforms.Compose([
transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size)
aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size)
full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size, cache_level=cache_level)
aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size, cache_level=cache_level)
indices = list(range(len(full_ds)))
random.shuffle(indices)

View File

@@ -21,10 +21,19 @@ class HomographyTrainer:
self.optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])
self.writer = None
self.best_val_loss = float("inf")
self.train_losses = []
self.val_losses = []
self.train_mse_trans = []
self.train_mse_angle = []
self.train_mse_scale = []
self.val_mse_trans = []
self.val_mse_angle = []
self.val_mse_scale = []
def train_epoch(self, epoch):
self.model.train()
total_loss, total_samples = 0, 0
mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
for batch_idx, batch in enumerate(pbar):
google_img = batch["google_img"].to(self.device)
@@ -39,13 +48,23 @@ class HomographyTrainer:
total_loss += loss.item() * google_img.size(0)
total_samples += google_img.size(0)
mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)
mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)
mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)
pbar.set_postfix({"loss": loss.item()})
self.train_mse_trans.append(mse_trans_sum / total_samples)
self.train_mse_angle.append(mse_angle_sum / total_samples)
self.train_mse_scale.append(mse_scale_sum / total_samples)
return {"loss": total_loss / total_samples}
def validate(self):
self.model.eval()
total_loss, total_samples = 0, 0
mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0
with torch.no_grad():
for batch in tqdm(self.val_loader, desc="Validation"):
google_img = batch["google_img"].to(self.device)
@@ -55,6 +74,15 @@ class HomographyTrainer:
loss = self.criterion(output, target)
total_loss += loss.item() * google_img.size(0)
total_samples += google_img.size(0)
mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)
mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)
mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)
self.val_mse_trans.append(mse_trans_sum / total_samples)
self.val_mse_angle.append(mse_angle_sum / total_samples)
self.val_mse_scale.append(mse_scale_sum / total_samples)
return {"loss": total_loss / total_samples}
def train(self, num_epochs):
@@ -65,7 +93,10 @@ class HomographyTrainer:
for epoch in range(1, num_epochs + 1):
train_metrics = self.train_epoch(epoch)
val_metrics = self.validate()
self.train_losses.append(train_metrics["loss"])
self.val_losses.append(val_metrics["loss"])
print(f"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}")
print(f" MSE - Trans: {self.val_mse_trans[-1]:.4f}, Angle: {self.val_mse_angle[-1]:.4f}, Scale: {self.val_mse_scale[-1]:.4f}")
if val_metrics["loss"] < self.best_val_loss:
self.best_val_loss = val_metrics["loss"]