917 lines
33 KiB
Python
917 lines
33 KiB
Python
|
||
import os
|
||
import time
|
||
from datetime import datetime
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
from torch.utils.data import DataLoader
|
||
from torch.utils.tensorboard import SummaryWriter
|
||
from tqdm import tqdm
|
||
|
||
# =============================================================================
|
||
# TRAINING LOOP WITH VISUALIZATION
|
||
# =============================================================================
|
||
|
||
class SimilarityTrainer:
    """Train and validate a pairwise similarity model with TensorBoard logging.

    Each batch is expected to be a mapping with the keys ``'google_img'``,
    ``'yandex_img'`` and ``'same_domain'``.  The two images are passed to the
    model, and its output is compared with ``same_domain`` (cast to float and
    given a trailing dimension) through ``SimilarityLoss``.

    Config keys read by this class (defaults in parentheses):
    ``learning_rate`` (2e-4), ``beta1`` (0.5), ``beta2`` (0.999),
    ``grad_clip`` (disabled), ``log_interval`` (10),
    ``output_dir`` ('runs/similarity'), ``save_interval`` (5),
    ``early_stopping_patience`` (20).
    """

    def __init__(
        self,
        model: nn.Module,
        trainloader: DataLoader,
        valloader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        """Move the model to ``device`` and build the loss and Adam optimizer."""
        self.model = model.to(device)
        self.trainloader = trainloader
        self.valloader = valloader
        self.device = device
        self.config = config

        self.criterion = SimilarityLoss()
        self.optimizer = optim.Adam(
            model.parameters(),
            lr=config.get('learning_rate', 2e-4),
            betas=(config.get('beta1', 0.5), config.get('beta2', 0.999)),
        )

        # The TensorBoard writer is created lazily by train().
        self.writer = None
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0

        # Per-epoch metric history, appended to by train().
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_accuracy': [],
            'val_precision': [],
            'val_recall': [],
            'val_f1': [],
            'learning_rate': [],
        }

    @staticmethod
    def _average_metrics(per_batch: list) -> dict:
        """Return the unweighted mean of every key across a list of metric dicts."""
        if not per_batch:
            return {}
        count = len(per_batch)
        return {
            key: sum(entry[key] for entry in per_batch) / count
            for key in per_batch[0]
        }

    def _unpack_batch(self, batch):
        """Move one batch to the trainer device and shape the target as (N, 1)."""
        google_img = batch['google_img'].to(self.device)
        yandex_img = batch['yandex_img'].to(self.device)
        target = batch['same_domain'].float().to(self.device).unsqueeze(1)
        return google_img, yandex_img, target

    def train_epoch(self, epoch: int) -> dict:
        """Run one optimisation pass over the training loader.

        Returns a dict with the sample-weighted mean ``'loss'`` plus the mean of
        the metrics computed on every ``log_interval``-th batch.

        Raises:
            ValueError: if the training loader yields no samples.
        """
        self.model.train()
        total_loss = 0
        total_samples = 0
        all_metrics = []

        grad_clip = self.config.get('grad_clip')
        log_interval = self.config.get('log_interval', 10)
        steps_per_epoch = len(self.trainloader)

        pbar = tqdm(self.trainloader, desc=f'Epoch {epoch}')
        for batch_idx, batch in enumerate(pbar):
            google_img, yandex_img, target = self._unpack_batch(batch)

            self.optimizer.zero_grad()

            # Forward pass
            output = self.model(google_img, yandex_img)
            loss = self.criterion(output, target)

            # Backward pass
            loss.backward()

            # Gradient clipping is applied only when a truthy limit is configured.
            if grad_clip:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)

            self.optimizer.step()

            batch_size = google_img.size(0)
            batch_loss = loss.item()
            total_loss += batch_loss * batch_size
            total_samples += batch_size

            # Metrics are sampled rather than computed on every batch.
            if batch_idx % log_interval == 0:
                metrics = self.criterion.compute_metrics(output, target)
                all_metrics.append(metrics)
                pbar.set_postfix({
                    'loss': f"{batch_loss:.4f}",
                    'acc': f"{metrics['accuracy']:.4f}"
                })

                if self.writer:
                    global_step = epoch * steps_per_epoch + batch_idx
                    self.writer.add_scalar('train/loss', batch_loss, global_step)
                    self.writer.add_scalar('train/accuracy', metrics['accuracy'], global_step)

        if total_samples == 0:
            raise ValueError('Training loader produced no samples')

        return {
            'loss': total_loss / total_samples,
            **self._average_metrics(all_metrics),
        }

    def validate(self) -> dict:
        """Evaluate the model on the validation loader without gradients.

        Returns a dict with the sample-weighted mean ``'loss'``, the per-batch
        metrics averaged with equal weight per batch, and the concatenated CPU
        tensors ``'predictions'`` and ``'targets'``.

        Raises:
            ValueError: if the validation loader yields no samples.
        """
        self.model.eval()
        total_loss = 0
        total_samples = 0
        all_metrics = []

        # Raw outputs and targets are kept for ROC / confusion-matrix plots.
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for batch in tqdm(self.valloader, desc='Validation'):
                google_img, yandex_img, target = self._unpack_batch(batch)

                output = self.model(google_img, yandex_img)
                loss = self.criterion(output, target)

                batch_size = google_img.size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size

                all_metrics.append(self.criterion.compute_metrics(output, target))
                all_predictions.append(output.cpu())
                all_targets.append(target.cpu())

        if total_samples == 0:
            raise ValueError('Validation loader produced no samples')

        return {
            'loss': total_loss / total_samples,
            **self._average_metrics(all_metrics),
            'predictions': torch.cat(all_predictions, dim=0),
            'targets': torch.cat(all_targets, dim=0),
        }

    def train(self, num_epochs: int):
        """Run the full training loop and return the metric history.

        For every epoch this trains, validates, records history, logs to
        TensorBoard, saves checkpoints (always on a new best validation loss,
        otherwise every ``save_interval`` epochs) and applies early stopping.
        The TensorBoard writer is closed even when an epoch raises.
        """
        log_dir = self.config.get('output_dir', 'runs/similarity')
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)

        print(f'\n{"="*70}')
        print(f'Starting training for {num_epochs} epochs')
        print(f'Logging to {log_dir}')
        print(f'{"="*70}\n')

        save_interval = self.config.get('save_interval', 5)
        patience = self.config.get('early_stopping_patience', 20)
        val_metric_names = ('accuracy', 'precision', 'recall', 'f1')

        start_time = time.time()

        try:
            for epoch in range(1, num_epochs + 1):
                epoch_start = time.time()
                print(f'\n--- Epoch {epoch}/{num_epochs} ---')

                train_metrics = self.train_epoch(epoch)
                val_metrics = self.validate()

                # Store history
                self.history['train_loss'].append(train_metrics['loss'])
                self.history['val_loss'].append(val_metrics['loss'])
                for name in val_metric_names:
                    self.history[f'val_{name}'].append(val_metrics[name])
                self.history['learning_rate'].append(
                    self.optimizer.param_groups[0]['lr']
                )

                # Print metrics
                print(f'\nTrain Loss: {train_metrics["loss"]:.4f}')
                print(f'Val Loss: {val_metrics["loss"]:.4f}')
                print(f'Val Accuracy: {val_metrics["accuracy"]:.4f}')
                print(f'Val Precision: {val_metrics["precision"]:.4f}')
                print(f'Val Recall: {val_metrics["recall"]:.4f}')
                print(f'Val F1: {val_metrics["f1"]:.4f}')

                epoch_time = time.time() - epoch_start
                print(f'Epoch time: {epoch_time:.2f}s')

                # TensorBoard logging
                if self.writer:
                    self.writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
                    self.writer.add_scalar('epoch/val_loss', val_metrics['loss'], epoch)
                    for name in val_metric_names:
                        self.writer.add_scalar(f'epoch/val_{name}', val_metrics[name], epoch)

                # Save checkpoint
                if val_metrics['loss'] < self.best_val_loss:
                    self.best_val_loss = val_metrics['loss']
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_metrics['loss'], is_best=True)
                    print(f'✓ New best model saved with val loss: {val_metrics["loss"]:.4f}')
                else:
                    self.epochs_without_improvement += 1
                    if epoch % save_interval == 0:
                        self.save_checkpoint(epoch, val_metrics['loss'], is_best=False)

                # Early stopping
                if self.epochs_without_improvement >= patience:
                    print(f'\n⚠ Early stopping triggered after {patience} epochs without improvement')
                    break
        finally:
            # Flush and release the event file even if an epoch failed.
            self.writer.close()

        total_time = time.time() - start_time
        print(f'\n{"="*70}')
        print(f'Training completed in {total_time/60:.2f} minutes')
        print(f'Best validation loss: {self.best_val_loss:.4f}')
        print(f'{"="*70}\n')

        return self.history

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write model, optimizer, config and history to ``<output_dir>/checkpoints``.

        The checkpoint is always saved as ``checkpoint_epoch{epoch}.pt`` and,
        when ``is_best`` is true, additionally as ``best_model.pt``.
        """
        checkpoint_dir = os.path.join(
            self.config.get('output_dir', 'runs/similarity'),
            'checkpoints'
        )
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
            'config': self.config,
            'history': self.history
        }

        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch}.pt')
        torch.save(checkpoint, checkpoint_path)

        if is_best:
            best_path = os.path.join(checkpoint_dir, 'best_model.pt')
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model, optimizer, history and best validation loss from a file.

        ``best_val_loss`` is set to the lowest validation loss recorded in the
        checkpoint, so that training resumed afterwards does not overwrite
        ``best_model.pt`` with a worse model.

        Returns:
            Tuple ``(epoch, val_loss)`` as stored in the checkpoint.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'history' in checkpoint:
            self.history = checkpoint['history']

        recorded_losses = list(checkpoint.get('history', {}).get('val_loss', []))
        recorded_losses.append(checkpoint['val_loss'])
        self.best_val_loss = min(recorded_losses)

        return checkpoint['epoch'], checkpoint['val_loss']
|
||
|
||
|
||
# =============================================================================
|
||
# VISUALIZATION FUNCTIONS
|
||
# =============================================================================
|
||
|
||
def plot_training_history(history: dict, save_path: str = None):
    """Plot the training curves stored in ``history`` on a 2x3 grid.

    Panels: train/val loss, validation accuracy, F1, precision, recall and the
    learning rate (log scale).  The figure is written to ``save_path`` when one
    is given and is then shown.
    """
    figure, grid = plt.subplots(2, 3, figsize=(18, 10))
    figure.suptitle('Training History - Siamese Network для корреляции снимков',
                    fontsize=16, fontweight='bold')

    epoch_axis = range(1, len(history['train_loss']) + 1)

    # Loss curves share one panel and are the only one with a legend.
    loss_ax = grid[0, 0]
    loss_ax.plot(epoch_axis, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    loss_ax.plot(epoch_axis, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Loss Curves')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Validation metrics drawn on a fixed [0, 1] y-range:
    # (axes, history key, line format, y label, title)
    bounded_panels = (
        (grid[0, 1], 'val_accuracy', 'g-', 'Accuracy', 'Validation Accuracy'),
        (grid[0, 2], 'val_f1', 'm-', 'F1 Score', 'Validation F1 Score'),
        (grid[1, 0], 'val_precision', 'c-', 'Precision', 'Validation Precision'),
        (grid[1, 1], 'val_recall', 'y-', 'Recall', 'Validation Recall'),
    )
    for panel, key, line_format, y_label, title in bounded_panels:
        panel.plot(epoch_axis, history[key], line_format, linewidth=2)
        panel.set_xlabel('Epoch')
        panel.set_ylabel(y_label)
        panel.set_title(title)
        panel.grid(True, alpha=0.3)
        panel.set_ylim([0, 1])

    # Learning rate
    lr_ax = grid[1, 2]
    lr_ax.plot(epoch_axis, history['learning_rate'], 'k-', linewidth=2)
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_title('Learning Rate Schedule')
    lr_ax.grid(True, alpha=0.3)
    lr_ax.set_yscale('log')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Training history plot saved to {save_path}')

    plt.show()
|
||
|
||
|
||
def plot_roc_curve(predictions: torch.Tensor, targets: torch.Tensor, save_path: str = None):
    """Plot the ROC curve of ``predictions`` against ``targets`` and return the AUC.

    Both tensors are converted with ``.numpy()`` and flattened before being
    handed to scikit-learn.  The figure is saved to ``save_path`` when given
    and then shown.
    """
    from sklearn.metrics import roc_curve, auc

    scores = predictions.numpy().flatten()
    truth = targets.numpy().flatten()

    false_pos_rate, true_pos_rate, _ = roc_curve(truth, scores)
    roc_auc = auc(false_pos_rate, true_pos_rate)

    plt.figure(figsize=(10, 8))
    plt.plot(
        false_pos_rate,
        true_pos_rate,
        color='darkorange',
        lw=2,
        label=f'ROC curve (AUC = {roc_auc:.3f})',
    )
    # Diagonal reference line labelled as the random baseline.
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate', fontsize=12)
    plt.title('ROC Curve - Siamese Network', fontsize=14, fontweight='bold')
    plt.legend(loc="lower right", fontsize=12)
    plt.grid(True, alpha=0.3)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'ROC curve saved to {save_path}')

    plt.show()

    return roc_auc
|
||
|
||
|
||
def plot_confusion_matrix(predictions: torch.Tensor, targets: torch.Tensor,
                          threshold: float = 0.5, save_path: str = None):
    """Plot the 2x2 confusion matrix of thresholded predictions.

    Predictions at or above ``threshold`` count as class 1; targets are
    binarised at 0.5.  The matrix is always computed for the label set
    ``[0, 1]`` so it stays 2x2 (and matches the two tick labels) even when
    only one class occurs in the data.  The figure is saved to ``save_path``
    when given and then shown.
    """
    from sklearn.metrics import confusion_matrix

    predictions_binary = (predictions.numpy().flatten() >= threshold).astype(int)
    targets_binary = (targets.numpy().flatten() >= 0.5).astype(int)

    # Explicit labels: without them a single-class input yields a 1x1 matrix.
    cm = confusion_matrix(targets_binary, predictions_binary, labels=[0, 1])

    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix - Correlation Detection', fontsize=14, fontweight='bold')
    plt.colorbar()

    classes = ['Different Domains', 'Same Domain']
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, fontsize=12)
    plt.yticks(tick_marks, classes, fontsize=12)

    # Cell counts; text colour flips at half the largest count for contrast.
    thresh = cm.max() / 2.
    for (row, col), count in np.ndenumerate(cm):
        plt.text(col, row, format(count, 'd'),
                 ha="center", va="center",
                 color="white" if count > thresh else "black",
                 fontsize=16, fontweight='bold')

    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Confusion matrix saved to {save_path}')

    plt.show()
|
||
|
||
|
||
def _print_score_stats(heading: str, scores) -> None:
    """Print mean/std/min/max of a 1-D score array, tolerating an empty array."""
    print(heading)
    if scores.size == 0:
        # min()/max() raise on empty arrays, so report the absence instead.
        print(' (no samples)')
        return
    print(f' Mean: {scores.mean():.4f}')
    print(f' Std: {scores.std():.4f}')
    print(f' Min: {scores.min():.4f}')
    print(f' Max: {scores.max():.4f}')


def plot_similarity_distribution(predictions: torch.Tensor, targets: torch.Tensor,
                                 save_path: str = None):
    """Plot histograms and box plots of predicted scores split by target class.

    Scores whose target is at least 0.5 form the "same domain" group, the rest
    the "different domains" group.  The figure is saved to ``save_path`` when
    given and then shown; summary statistics for both groups are printed
    afterwards.  A group without samples is reported as such instead of
    raising.
    """
    predictions_np = predictions.numpy().flatten()
    targets_np = targets.numpy().flatten()

    is_same = targets_np >= 0.5
    same_domain = predictions_np[is_same]
    diff_domain = predictions_np[~is_same]

    plt.figure(figsize=(12, 6))

    # Left panel: overlaid histograms of the two groups.
    plt.subplot(1, 2, 1)
    plt.hist(same_domain, bins=50, alpha=0.7, color='green', edgecolor='black', label='Same Domain')
    plt.hist(diff_domain, bins=50, alpha=0.7, color='red', edgecolor='black', label='Different Domains')
    plt.xlabel('Predicted Similarity Score', fontsize=12)
    plt.ylabel('Frequency', fontsize=12)
    plt.title('Distribution of Similarity Scores', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

    # Right panel: box plot per group.
    plt.subplot(1, 2, 2)
    plt.boxplot([diff_domain, same_domain], labels=['Different', 'Same'])
    plt.ylabel('Similarity Score', fontsize=12)
    plt.xlabel('Domain Match', fontsize=12)
    plt.title('Similarity Score by Domain Match', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Similarity distribution plot saved to {save_path}')

    plt.show()

    # Print statistics
    print('\n--- Similarity Score Statistics ---')
    _print_score_stats('Same Domain:', same_domain)
    _print_score_stats('\nDifferent Domains:', diff_domain)
|
||
|
||
|
||
def visualize_sample_predictions(model: nn.Module, dataset, device: torch.device,
                                 num_samples: int = 8, save_path: str = None):
    """Show randomly chosen image pairs next to the model's prediction.

    Each row contains the Google image, the Yandex image and a text box with
    the predicted score, the predicted label (score >= 0.5) and the true label.
    ``num_samples`` is capped at ``len(dataset)``; nothing is drawn for an
    empty dataset.  The figure is saved to ``save_path`` when given and then
    shown.
    """
    model.eval()

    # Sampling without replacement cannot draw more items than exist.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print('No samples available for visualization')
        return

    indices = np.random.choice(len(dataset), num_samples, replace=False)

    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    fig.suptitle('Sample Predictions - Siamese Network для корреляции карт',
                 fontsize=16, fontweight='bold')

    # Denormalization constants (assuming ImageNet normalization).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    def to_displayable(image_batch):
        """Convert a (1, C, H, W) tensor to a clipped H x W x C array."""
        pixels = image_batch.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        return np.clip(std * pixels + mean, 0, 1)

    with torch.no_grad():
        for row, sample_idx in enumerate(indices):
            sample = dataset[sample_idx]

            google_img = sample['google_img'].unsqueeze(0).to(device)
            yandex_img = sample['yandex_img'].unsqueeze(0).to(device)
            true_label = sample['same_domain'].item()

            # Predict
            pred_similarity = model(google_img, yandex_img).item()
            pred_label = int(pred_similarity >= 0.5)

            # Plot Google image
            axes[row, 0].imshow(to_displayable(google_img))
            axes[row, 0].set_title('Google Map', fontsize=12, fontweight='bold')
            axes[row, 0].axis('off')

            # Plot Yandex image
            axes[row, 1].imshow(to_displayable(yandex_img))
            axes[row, 1].set_title('Yandex Map', fontsize=12, fontweight='bold')
            axes[row, 1].axis('off')

            # Plot prediction info
            axes[row, 2].axis('off')

            # Box colour reflects whether the prediction matches the label.
            is_correct = (pred_label == true_label)
            color = 'green' if is_correct else 'red'
            result = '✓ Correct' if is_correct else '✗ Incorrect'

            info_text = (
                "\n"
                f"Prediction: {pred_similarity:.4f}\n"
                f"Predicted Label: {'Same' if pred_label == 1 else 'Different'}\n"
                f"True Label: {'Same' if true_label == 1 else 'Different'}\n"
                "\n"
                f"{result}\n"
            )

            axes[row, 2].text(0.5, 0.5, info_text,
                              ha='center', va='center',
                              fontsize=12,
                              bbox=dict(boxstyle='round', facecolor=color, alpha=0.2),
                              transform=axes[row, 2].transAxes)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Sample predictions saved to {save_path}')

    plt.show()
|
||
|
||
|
||
def visualize_feature_space(model: nn.Module, dataloader: DataLoader,
                            device: torch.device, max_samples: int = 500,
                            save_path: str = None):
    """Project per-image features to 2-D with t-SNE and plot them by pair label.

    Features are obtained with ``model.extract_features`` for the Google and
    Yandex images of at most ``max_samples`` pairs, embedded jointly, and drawn
    in two panels (one per map source) coloured by ``same_domain``.  The figure
    is saved to ``save_path`` when given and then shown.

    Raises:
        ValueError: if no samples could be collected from ``dataloader``.
    """
    from sklearn.manifold import TSNE

    model.eval()

    google_chunks = []
    yandex_chunks = []
    label_chunks = []
    collected = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Extracting features'):
            # Count actual samples instead of relying on dataloader.batch_size,
            # which is None for loaders built from a batch sampler.
            remaining = max_samples - collected
            if remaining <= 0:
                break

            google_img = batch['google_img'][:remaining].to(device)
            yandex_img = batch['yandex_img'][:remaining].to(device)
            labels = batch['same_domain'][:remaining].cpu().numpy()

            # Extract features
            google_chunks.append(model.extract_features(google_img).cpu().numpy())
            yandex_chunks.append(model.extract_features(yandex_img).cpu().numpy())
            label_chunks.append(labels)
            collected += google_img.size(0)

    if collected == 0:
        raise ValueError('No samples available for feature-space visualization')

    google_features = np.concatenate(google_chunks, axis=0)
    yandex_features = np.concatenate(yandex_chunks, axis=0)
    labels = np.concatenate(label_chunks, axis=0)

    # Embed both sources in one t-SNE run so they share a coordinate system.
    all_features = np.concatenate([google_features, yandex_features], axis=0)

    print(f'\nApplying t-SNE to {all_features.shape[0]} samples...')
    # t-SNE requires perplexity < number of points; clamp for small inputs.
    perplexity = min(30, all_features.shape[0] - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    features_2d = tsne.fit_transform(all_features)

    # Split back into Google and Yandex
    n_samples = google_features.shape[0]
    panels = (
        (features_2d[:n_samples], 'Google Maps Features (t-SNE)'),
        (features_2d[n_samples:], 'Yandex Maps Features (t-SNE)'),
    )

    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    for ax, (points, title) in zip(axes, panels):
        for label in [0, 1]:
            mask = labels == label
            ax.scatter(
                points[mask, 0],
                points[mask, 1],
                c='green' if label == 1 else 'red',
                label='Same Domain' if label == 1 else 'Different Domains',
                alpha=0.6,
                s=50
            )
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('t-SNE Component 1', fontsize=12)
        ax.set_ylabel('t-SNE Component 2', fontsize=12)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Feature space visualization saved to {save_path}')

    plt.show()
|
||
|
||
|
||
def generate_correlation_heatmap(model: nn.Module, dataloader: DataLoader,
                                 device: torch.device, num_samples: int = 20,
                                 save_path: str = None):
    """Plot the model's similarity for every Google/Yandex image combination.

    The first pair of each batch is collected until ``num_samples`` pairs are
    available or the loader is exhausted; the matrix is sized by the number of
    pairs actually collected.  Entry ``[i, j]`` is the model output for Google
    image ``i`` and Yandex image ``j``.  The figure is saved to ``save_path``
    when given and then shown, after which diagonal and off-diagonal
    statistics are printed.

    Raises:
        ValueError: if no samples could be collected from ``dataloader``.
    """
    model.eval()

    # Collect samples
    google_images = []
    yandex_images = []
    labels = []

    with torch.no_grad():
        for batch in dataloader:
            if len(google_images) >= num_samples:
                break

            google_images.append(batch['google_img'][:1].to(device))
            yandex_images.append(batch['yandex_img'][:1].to(device))
            labels.append(batch['same_domain'][:1].item())

    # The loader may contain fewer batches than requested samples.
    sample_count = len(google_images)
    if sample_count == 0:
        raise ValueError('No samples available for the correlation heatmap')

    google_images = torch.cat(google_images, dim=0)
    yandex_images = torch.cat(yandex_images, dim=0)

    # Compute similarity matrix
    similarity_matrix = np.zeros((sample_count, sample_count))

    with torch.no_grad():
        for i in tqdm(range(sample_count), desc='Computing correlations'):
            google_i = google_images[i:i + 1]
            for j in range(sample_count):
                yandex_j = yandex_images[j:j + 1]
                similarity_matrix[i, j] = model(google_i, yandex_j).item()

    # Plot heatmap
    plt.figure(figsize=(14, 12))
    im = plt.imshow(similarity_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)

    plt.colorbar(im, label='Similarity Score', fraction=0.046, pad=0.04)
    plt.title('Correlation Heatmap: Google vs Yandex Maps\n(Матрица корреляций снимков)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Yandex Map Index', fontsize=12)
    plt.ylabel('Google Map Index', fontsize=12)

    # Add grid
    plt.grid(True, which='both', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    # Mark diagonal cells whose dataset label is 1 (first 10 only, for readability).
    for i in range(min(sample_count, 10)):
        if labels[i] == 1:
            plt.text(i, i, '✓', ha='center', va='center',
                     color='white', fontsize=12, fontweight='bold')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Correlation heatmap saved to {save_path}')

    plt.show()

    # Print statistics
    diagonal = np.diag(similarity_matrix)
    off_diagonal = similarity_matrix[~np.eye(sample_count, dtype=bool)]

    print('\n--- Correlation Statistics ---')
    print('Diagonal (matched pairs):')
    print(f' Mean: {diagonal.mean():.4f}')
    print(f' Std: {diagonal.std():.4f}')
    print('\nOff-diagonal (mismatched pairs):')
    print(f' Mean: {off_diagonal.mean():.4f}')
    print(f' Std: {off_diagonal.std():.4f}')
|
||
|
||
|
||
# =============================================================================
|
||
# MAIN TRAINING SCRIPT
|
||
# =============================================================================
|
||
|
||
def main():
    """Train the similarity model and generate all evaluation visualizations.

    Reads the module-level ``config`` mapping, builds the data loaders and the
    model, trains with ``SimilarityTrainer`` and writes plots to
    ``<output_dir>/visualizations``.
    """
    rule = '=' * 70

    # Configuration (copied so the module-level mapping is not mutated).
    config_dict = config.copy()
    if isinstance(config_dict.get('image_size'), list):
        config_dict['image_size'] = tuple(config_dict['image_size'])

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'\n{rule}')
    print('Siamese Network Training for Map Correlation')
    print('Обучение сиамской сети для корреляции снимков')
    print(rule)
    print(f'Using device: {device}')
    if torch.cuda.is_available():
        print(f'GPU: {torch.cuda.get_device_name(0)}')
        print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')

    # Create data loaders
    print(f'\n{rule}')
    print('Creating data loaders...')
    print(rule)

    train_loader, val_loader = create_data_loaders(
        root_dir=config_dict['data_dir'],
        batch_size=config_dict['batch_size'],
        train_split=config_dict['train_split'],
        num_workers=config_dict['num_workers'],
        image_size=config_dict['image_size'],
        augment_train=True,
        augment_val=False,
        device=device
    )

    print(f'Train batches: {len(train_loader)}')
    print(f'Val batches: {len(val_loader)}')
    print(f'Train samples: {len(train_loader.dataset)}')
    print(f'Val samples: {len(val_loader.dataset)}')

    # Create model
    print(f'\n{rule}')
    print('Creating model...')
    print(rule)

    # The model takes a scalar size; use the first element of a (tuple/list) size.
    image_size = config_dict['image_size']
    input_size = image_size[0] if isinstance(image_size, (tuple, list)) else image_size

    model = create_similarity_model(
        model_type='backbone',
        input_size=input_size,
        input_channels=3,
        backbone_name='resnet18',
        pretrained=True,
        dropout_rate=0.3,
        use_batch_norm=True
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total parameters: {total_params:,}')
    print(f'Trainable parameters: {trainable_params:,}')

    # Create trainer
    trainer = SimilarityTrainer(
        model=model,
        trainloader=train_loader,
        valloader=val_loader,
        device=device,
        config=config_dict
    )

    # Train model
    print(f'\n{rule}')
    print('Starting training...')
    print(rule)

    history = trainer.train(config_dict['epochs'])

    # =========================================================================
    # VISUALIZATION AND RESULTS
    # =========================================================================

    print(f'\n{rule}')
    print('Generating visualizations...')
    print(rule)

    output_dir = config_dict.get('output_dir', 'runs/similarity')
    vis_dir = os.path.join(output_dir, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)

    # 1. Training history
    print('\n1. Plotting training history...')
    plot_training_history(
        history,
        save_path=os.path.join(vis_dir, 'training_history.png')
    )

    # 2. Validation predictions: reuse the trainer's validation pass, which
    # already returns the concatenated CPU predictions and targets.
    print('\n2. Computing validation predictions...')
    val_results = trainer.validate()
    val_predictions = val_results['predictions']
    val_targets = val_results['targets']

    # 3. ROC curve
    print('\n3. Plotting ROC curve...')
    roc_auc = plot_roc_curve(
        val_predictions,
        val_targets,
        save_path=os.path.join(vis_dir, 'roc_curve.png')
    )
    print(f'ROC AUC Score: {roc_auc:.4f}')

    # 4. Confusion matrix
    print('\n4. Plotting confusion matrix...')
    plot_confusion_matrix(
        val_predictions,
        val_targets,
        threshold=0.5,
        save_path=os.path.join(vis_dir, 'confusion_matrix.png')
    )

    # 5. Similarity distribution
    print('\n5. Plotting similarity distribution...')
    plot_similarity_distribution(
        val_predictions,
        val_targets,
        save_path=os.path.join(vis_dir, 'similarity_distribution.png')
    )

    # 6. Sample predictions
    print('\n6. Visualizing sample predictions...')
    visualize_sample_predictions(
        trainer.model,
        val_loader.dataset,
        device,
        num_samples=8,
        save_path=os.path.join(vis_dir, 'sample_predictions.png')
    )

    # 7. Feature space visualization
    print('\n7. Visualizing feature space (t-SNE)...')
    visualize_feature_space(
        trainer.model,
        val_loader,
        device,
        max_samples=500,
        save_path=os.path.join(vis_dir, 'feature_space_tsne.png')
    )

    # 8. Correlation heatmap
    print('\n8. Generating correlation heatmap...')
    generate_correlation_heatmap(
        trainer.model,
        val_loader,
        device,
        num_samples=20,
        save_path=os.path.join(vis_dir, 'correlation_heatmap.png')
    )

    # =========================================================================
    # FINAL RESULTS SUMMARY
    # =========================================================================

    print(f'\n{rule}')
    print('FINAL RESULTS SUMMARY')
    print('ИТОГОВЫЕ РЕЗУЛЬТАТЫ')
    print(rule)

    print(f'\nBest Validation Loss: {trainer.best_val_loss:.4f}')
    # History is empty when training ran for zero epochs.
    if history['val_accuracy']:
        print(f'Final Validation Accuracy: {history["val_accuracy"][-1]:.4f}')
        print(f'Final Validation F1 Score: {history["val_f1"][-1]:.4f}')
        print(f'Final Validation Precision: {history["val_precision"][-1]:.4f}')
        print(f'Final Validation Recall: {history["val_recall"][-1]:.4f}')
    print(f'ROC AUC Score: {roc_auc:.4f}')

    print(f'\nCheckpoints saved to: {os.path.join(output_dir, "checkpoints")}')
    print(f'Visualizations saved to: {vis_dir}')

    print(f'\n{rule}')
    print('Training and visualization completed successfully!')
    print('Обучение и визуализация завершены успешно!')
    print(f'{rule}\n')
|
||
|
||
|
||
# Run the training pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()