feat: complete sian-similarity

This commit is contained in:
2026-03-22 14:29:00 +03:00
parent 43cd4222bc
commit 05f8746d58
8 changed files with 3780 additions and 1903 deletions

View File

@@ -0,0 +1,917 @@
import os
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# =============================================================================
# TRAINING LOOP WITH VISUALIZATION
# =============================================================================
class SimilarityTrainer:
def __init__(
self,
model: nn.Module,
trainloader: DataLoader,
valloader: DataLoader,
device: torch.device,
config: dict,
):
self.model = model.to(device)
self.trainloader = trainloader
self.valloader = valloader
self.device = device
self.config = config
self.criterion = SimilarityLoss()
self.optimizer = optim.Adam(
model.parameters(),
lr=config.get('learning_rate', 2e-4),
betas=(config.get('beta1', 0.5), config.get('beta2', 0.999))
)
self.writer = None
self.best_val_loss = float('inf')
self.epochs_without_improvement = 0
# Для хранения истории метрик
self.history = {
'train_loss': [],
'val_loss': [],
'val_accuracy': [],
'val_precision': [],
'val_recall': [],
'val_f1': [],
'learning_rate': []
}
def train_epoch(self, epoch: int) -> dict:
"""Обучение на одной эпохе"""
self.model.train()
total_loss = 0
total_samples = 0
all_metrics = []
pbar = tqdm(self.trainloader, desc=f'Epoch {epoch}')
for batch_idx, batch in enumerate(pbar):
google_img = batch['google_img'].to(self.device)
yandex_img = batch['yandex_img'].to(self.device)
target = batch['same_domain'].float().to(self.device).unsqueeze(1)
self.optimizer.zero_grad()
# Forward pass
output = self.model(google_img, yandex_img)
loss = self.criterion(output, target)
# Backward pass
loss.backward()
# Gradient clipping
if self.config.get('grad_clip', None):
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config['grad_clip']
)
self.optimizer.step()
total_loss += loss.item() * google_img.size(0)
total_samples += google_img.size(0)
# Compute metrics
if batch_idx % self.config.get('log_interval', 10) == 0:
metrics = self.criterion.compute_metrics(output, target)
all_metrics.append(metrics)
pbar.set_postfix({
'loss': f"{loss.item():.4f}",
'acc': f"{metrics['accuracy']:.4f}"
})
if self.writer:
global_step = epoch * len(self.trainloader) + batch_idx
self.writer.add_scalar('train/loss', loss.item(), global_step)
self.writer.add_scalar('train/accuracy', metrics['accuracy'], global_step)
avg_loss = total_loss / total_samples
# Average metrics
if all_metrics:
avg_metrics = {
key: sum(m[key] for m in all_metrics) / len(all_metrics)
for key in all_metrics[0].keys()
}
else:
avg_metrics = {}
return {'loss': avg_loss, **avg_metrics}
def validate(self) -> dict:
"""Валидация модели"""
self.model.eval()
total_loss = 0
total_samples = 0
all_metrics = []
# Для ROC и confusion matrix
all_predictions = []
all_targets = []
with torch.no_grad():
for batch in tqdm(self.valloader, desc='Validation'):
google_img = batch['google_img'].to(self.device)
yandex_img = batch['yandex_img'].to(self.device)
target = batch['same_domain'].float().to(self.device).unsqueeze(1)
output = self.model(google_img, yandex_img)
loss = self.criterion(output, target)
total_loss += loss.item() * google_img.size(0)
total_samples += google_img.size(0)
metrics = self.criterion.compute_metrics(output, target)
all_metrics.append(metrics)
all_predictions.append(output.cpu())
all_targets.append(target.cpu())
avg_loss = total_loss / total_samples
avg_metrics = {
key: sum(m[key] for m in all_metrics) / len(all_metrics)
for key in all_metrics[0].keys()
}
# Concatenate all predictions and targets
all_predictions = torch.cat(all_predictions, dim=0)
all_targets = torch.cat(all_targets, dim=0)
return {
'loss': avg_loss,
**avg_metrics,
'predictions': all_predictions,
'targets': all_targets
}
def train(self, num_epochs: int):
"""Основной цикл обучения"""
log_dir = os.path.join(self.config.get('output_dir', 'runs/similarity'))
os.makedirs(log_dir, exist_ok=True)
self.writer = SummaryWriter(log_dir)
print(f'\n{"="*70}')
print(f'Starting training for {num_epochs} epochs')
print(f'Logging to {log_dir}')
print(f'{"="*70}\n')
start_time = time.time()
for epoch in range(1, num_epochs + 1):
epoch_start = time.time()
print(f'\n--- Epoch {epoch}/{num_epochs} ---')
# Train
train_metrics = self.train_epoch(epoch)
# Validate
val_metrics = self.validate()
# Store history
self.history['train_loss'].append(train_metrics['loss'])
self.history['val_loss'].append(val_metrics['loss'])
self.history['val_accuracy'].append(val_metrics['accuracy'])
self.history['val_precision'].append(val_metrics['precision'])
self.history['val_recall'].append(val_metrics['recall'])
self.history['val_f1'].append(val_metrics['f1'])
self.history['learning_rate'].append(
self.optimizer.param_groups[0]['lr']
)
# Print metrics
print(f'\nTrain Loss: {train_metrics["loss"]:.4f}')
print(f'Val Loss: {val_metrics["loss"]:.4f}')
print(f'Val Accuracy: {val_metrics["accuracy"]:.4f}')
print(f'Val Precision: {val_metrics["precision"]:.4f}')
print(f'Val Recall: {val_metrics["recall"]:.4f}')
print(f'Val F1: {val_metrics["f1"]:.4f}')
epoch_time = time.time() - epoch_start
print(f'Epoch time: {epoch_time:.2f}s')
# TensorBoard logging
if self.writer:
self.writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
self.writer.add_scalar('epoch/val_loss', val_metrics['loss'], epoch)
self.writer.add_scalar('epoch/val_accuracy', val_metrics['accuracy'], epoch)
self.writer.add_scalar('epoch/val_precision', val_metrics['precision'], epoch)
self.writer.add_scalar('epoch/val_recall', val_metrics['recall'], epoch)
self.writer.add_scalar('epoch/val_f1', val_metrics['f1'], epoch)
# Save checkpoint
if val_metrics['loss'] < self.best_val_loss:
self.best_val_loss = val_metrics['loss']
self.epochs_without_improvement = 0
self.save_checkpoint(epoch, val_metrics['loss'], is_best=True)
print(f'✓ New best model saved with val loss: {val_metrics["loss"]:.4f}')
else:
self.epochs_without_improvement += 1
if epoch % self.config.get('save_interval', 5) == 0:
self.save_checkpoint(epoch, val_metrics['loss'], is_best=False)
# Early stopping
patience = self.config.get('early_stopping_patience', 20)
if self.epochs_without_improvement >= patience:
print(f'\n⚠ Early stopping triggered after {patience} epochs without improvement')
break
total_time = time.time() - start_time
print(f'\n{"="*70}')
print(f'Training completed in {total_time/60:.2f} minutes')
print(f'Best validation loss: {self.best_val_loss:.4f}')
print(f'{"="*70}\n')
self.writer.close()
return self.history
def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
"""Сохранение чекпоинта модели"""
checkpoint_dir = os.path.join(
self.config.get('output_dir', 'runs/similarity'),
'checkpoints'
)
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'val_loss': val_loss,
'config': self.config,
'history': self.history
}
checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch}.pt')
torch.save(checkpoint, checkpoint_path)
if is_best:
best_path = os.path.join(checkpoint_dir, 'best_model.pt')
torch.save(checkpoint, best_path)
def load_checkpoint(self, checkpoint_path: str):
"""Загрузка чекпоинта"""
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'history' in checkpoint:
self.history = checkpoint['history']
return checkpoint['epoch'], checkpoint['val_loss']
# =============================================================================
# VISUALIZATION FUNCTIONS
# =============================================================================
def plot_training_history(history: dict, save_path: str = None):
"""Построение графиков обучения"""
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle('Training History - Siamese Network для корреляции снимков',
fontsize=16, fontweight='bold')
epochs = range(1, len(history['train_loss']) + 1)
# Loss
axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Loss Curves')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)
# Accuracy
axes[0, 1].plot(epochs, history['val_accuracy'], 'g-', linewidth=2)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].set_title('Validation Accuracy')
axes[0, 1].grid(True, alpha=0.3)
axes[0, 1].set_ylim([0, 1])
# F1 Score
axes[0, 2].plot(epochs, history['val_f1'], 'm-', linewidth=2)
axes[0, 2].set_xlabel('Epoch')
axes[0, 2].set_ylabel('F1 Score')
axes[0, 2].set_title('Validation F1 Score')
axes[0, 2].grid(True, alpha=0.3)
axes[0, 2].set_ylim([0, 1])
# Precision
axes[1, 0].plot(epochs, history['val_precision'], 'c-', linewidth=2)
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Precision')
axes[1, 0].set_title('Validation Precision')
axes[1, 0].grid(True, alpha=0.3)
axes[1, 0].set_ylim([0, 1])
# Recall
axes[1, 1].plot(epochs, history['val_recall'], 'y-', linewidth=2)
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Recall')
axes[1, 1].set_title('Validation Recall')
axes[1, 1].grid(True, alpha=0.3)
axes[1, 1].set_ylim([0, 1])
# Learning Rate
axes[1, 2].plot(epochs, history['learning_rate'], 'k-', linewidth=2)
axes[1, 2].set_xlabel('Epoch')
axes[1, 2].set_ylabel('Learning Rate')
axes[1, 2].set_title('Learning Rate Schedule')
axes[1, 2].grid(True, alpha=0.3)
axes[1, 2].set_yscale('log')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'Training history plot saved to {save_path}')
plt.show()
def plot_roc_curve(predictions: torch.Tensor, targets: torch.Tensor, save_path: str = None):
"""Построение ROC кривой"""
from sklearn.metrics import roc_curve, auc
predictions_np = predictions.numpy().flatten()
targets_np = targets.numpy().flatten()
fpr, tpr, thresholds = roc_curve(targets_np, predictions_np)
roc_auc = auc(fpr, tpr)
plt.figure(figsize=(10, 8))
plt.plot(fpr, tpr, color='darkorange', lw=2,
label=f'ROC curve (AUC = {roc_auc:.3f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('ROC Curve - Siamese Network', fontsize=14, fontweight='bold')
plt.legend(loc="lower right", fontsize=12)
plt.grid(True, alpha=0.3)
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'ROC curve saved to {save_path}')
plt.show()
return roc_auc
def plot_confusion_matrix(predictions: torch.Tensor, targets: torch.Tensor,
threshold: float = 0.5, save_path: str = None):
"""Построение матрицы ошибок"""
from sklearn.metrics import confusion_matrix
predictions_binary = (predictions.numpy().flatten() >= threshold).astype(int)
targets_binary = (targets.numpy().flatten() >= 0.5).astype(int)
cm = confusion_matrix(targets_binary, predictions_binary)
plt.figure(figsize=(10, 8))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix - Correlation Detection', fontsize=14, fontweight='bold')
plt.colorbar()
classes = ['Different Domains', 'Same Domain']
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45, fontsize=12)
plt.yticks(tick_marks, classes, fontsize=12)
# Add text annotations
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
plt.text(j, i, format(cm[i, j], 'd'),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black",
fontsize=16, fontweight='bold')
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'Confusion matrix saved to {save_path}')
plt.show()
def plot_similarity_distribution(predictions: torch.Tensor, targets: torch.Tensor,
save_path: str = None):
"""Распределение предсказанных значений схожести"""
predictions_np = predictions.numpy().flatten()
targets_np = targets.numpy().flatten()
same_domain = predictions_np[targets_np >= 0.5]
diff_domain = predictions_np[targets_np < 0.5]
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.hist(same_domain, bins=50, alpha=0.7, color='green', edgecolor='black', label='Same Domain')
plt.hist(diff_domain, bins=50, alpha=0.7, color='red', edgecolor='black', label='Different Domains')
plt.xlabel('Predicted Similarity Score', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.title('Distribution of Similarity Scores', fontsize=14, fontweight='bold')
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.subplot(1, 2, 2)
plt.boxplot([diff_domain, same_domain], labels=['Different', 'Same'])
plt.ylabel('Similarity Score', fontsize=12)
plt.xlabel('Domain Match', fontsize=12)
plt.title('Similarity Score by Domain Match', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'Similarity distribution plot saved to {save_path}')
plt.show()
# Print statistics
print(f'\n--- Similarity Score Statistics ---')
print(f'Same Domain:')
print(f' Mean: {same_domain.mean():.4f}')
print(f' Std: {same_domain.std():.4f}')
print(f' Min: {same_domain.min():.4f}')
print(f' Max: {same_domain.max():.4f}')
print(f'\nDifferent Domains:')
print(f' Mean: {diff_domain.mean():.4f}')
print(f' Std: {diff_domain.std():.4f}')
print(f' Min: {diff_domain.min():.4f}')
print(f' Max: {diff_domain.max():.4f}')
def visualize_sample_predictions(model: nn.Module, dataset, device: torch.device,
num_samples: int = 8, save_path: str = None):
"""Визуализация примеров предсказаний"""
model.eval()
# Get random samples
indices = np.random.choice(len(dataset), num_samples, replace=False)
fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5*num_samples))
if num_samples == 1:
axes = axes.reshape(1, -1)
fig.suptitle('Sample Predictions - Siamese Network для корреляции карт',
fontsize=16, fontweight='bold')
with torch.no_grad():
for idx, sample_idx in enumerate(indices):
sample = dataset[sample_idx]
google_img = sample['google_img'].unsqueeze(0).to(device)
yandex_img = sample['yandex_img'].unsqueeze(0).to(device)
true_label = sample['same_domain'].item()
# Predict
pred_similarity = model(google_img, yandex_img).item()
pred_label = int(pred_similarity >= 0.5)
# Denormalize images for visualization
google_np = google_img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
yandex_np = yandex_img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
# Denormalize (assuming ImageNet normalization)
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
google_np = std * google_np + mean
yandex_np = std * yandex_np + mean
google_np = np.clip(google_np, 0, 1)
yandex_np = np.clip(yandex_np, 0, 1)
# Plot Google image
axes[idx, 0].imshow(google_np)
axes[idx, 0].set_title('Google Map', fontsize=12, fontweight='bold')
axes[idx, 0].axis('off')
# Plot Yandex image
axes[idx, 1].imshow(yandex_np)
axes[idx, 1].set_title('Yandex Map', fontsize=12, fontweight='bold')
axes[idx, 1].axis('off')
# Plot prediction info
axes[idx, 2].axis('off')
# Determine color based on correctness
is_correct = (pred_label == true_label)
color = 'green' if is_correct else 'red'
result = '✓ Correct' if is_correct else '✗ Incorrect'
info_text = f"""
Prediction: {pred_similarity:.4f}
Predicted Label: {'Same' if pred_label == 1 else 'Different'}
True Label: {'Same' if true_label == 1 else 'Different'}
{result}
"""
axes[idx, 2].text(0.5, 0.5, info_text,
ha='center', va='center',
fontsize=12,
bbox=dict(boxstyle='round', facecolor=color, alpha=0.2),
transform=axes[idx, 2].transAxes)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'Sample predictions saved to {save_path}')
plt.show()
def visualize_feature_space(model: nn.Module, dataloader: DataLoader,
device: torch.device, max_samples: int = 500,
save_path: str = None):
"""Визуализация пространства признаков с помощью t-SNE"""
from sklearn.manifold import TSNE
model.eval()
all_features_google = []
all_features_yandex = []
all_labels = []
with torch.no_grad():
for i, batch in enumerate(tqdm(dataloader, desc='Extracting features')):
if i * dataloader.batch_size >= max_samples:
break
google_img = batch['google_img'].to(device)
yandex_img = batch['yandex_img'].to(device)
labels = batch['same_domain'].cpu().numpy()
# Extract features
features_google = model.extract_features(google_img).cpu().numpy()
features_yandex = model.extract_features(yandex_img).cpu().numpy()
all_features_google.append(features_google)
all_features_yandex.append(features_yandex)
all_labels.append(labels)
all_features_google = np.concatenate(all_features_google, axis=0)
all_features_yandex = np.concatenate(all_features_yandex, axis=0)
all_labels = np.concatenate(all_labels, axis=0)
# Combine features
all_features = np.concatenate([all_features_google, all_features_yandex], axis=0)
all_labels = np.concatenate([all_labels, all_labels], axis=0)
print(f'\nApplying t-SNE to {all_features.shape[0]} samples...')
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
features_2d = tsne.fit_transform(all_features)
# Split back into Google and Yandex
n_samples = len(all_labels) // 2
features_google_2d = features_2d[:n_samples]
features_yandex_2d = features_2d[n_samples:]
labels = all_labels[:n_samples]
# Plot
fig, axes = plt.subplots(1, 2, figsize=(20, 8))
# Google features
for label in [0, 1]:
mask = labels == label
axes[0].scatter(
features_google_2d[mask, 0],
features_google_2d[mask, 1],
c='green' if label == 1 else 'red',
label='Same Domain' if label == 1 else 'Different Domains',
alpha=0.6,
s=50
)
axes[0].set_title('Google Maps Features (t-SNE)', fontsize=14, fontweight='bold')
axes[0].set_xlabel('t-SNE Component 1', fontsize=12)
axes[0].set_ylabel('t-SNE Component 2', fontsize=12)
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)
# Yandex features
for label in [0, 1]:
mask = labels == label
axes[1].scatter(
features_yandex_2d[mask, 0],
features_yandex_2d[mask, 1],
c='green' if label == 1 else 'red',
label='Same Domain' if label == 1 else 'Different Domains',
alpha=0.6,
s=50
)
axes[1].set_title('Yandex Maps Features (t-SNE)', fontsize=14, fontweight='bold')
axes[1].set_xlabel('t-SNE Component 1', fontsize=12)
axes[1].set_ylabel('t-SNE Component 2', fontsize=12)
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'Feature space visualization saved to {save_path}')
plt.show()
def generate_correlation_heatmap(model: nn.Module, dataloader: DataLoader,
device: torch.device, num_samples: int = 20,
save_path: str = None):
"""Создание тепловой карты корреляций между снимками"""
model.eval()
# Collect samples
google_images = []
yandex_images = []
labels = []
with torch.no_grad():
for i, batch in enumerate(dataloader):
if len(google_images) >= num_samples:
break
google_img = batch['google_img'].to(device)
yandex_img = batch['yandex_img'].to(device)
label = batch['same_domain']
google_images.append(google_img[:1])
yandex_images.append(yandex_img[:1])
labels.append(label[:1].item())
google_images = torch.cat(google_images[:num_samples], dim=0)
yandex_images = torch.cat(yandex_images[:num_samples], dim=0)
# Compute similarity matrix
similarity_matrix = np.zeros((num_samples, num_samples))
with torch.no_grad():
for i in tqdm(range(num_samples), desc='Computing correlations'):
for j in range(num_samples):
google_i = google_images[i:i+1]
yandex_j = yandex_images[j:j+1]
similarity = model(google_i, yandex_j).item()
similarity_matrix[i, j] = similarity
# Plot heatmap
plt.figure(figsize=(14, 12))
im = plt.imshow(similarity_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
plt.colorbar(im, label='Similarity Score', fraction=0.046, pad=0.04)
plt.title('Correlation Heatmap: Google vs Yandex Maps\n(Матрица корреляций снимков)',
fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Yandex Map Index', fontsize=12)
plt.ylabel('Google Map Index', fontsize=12)
# Add grid
plt.grid(True, which='both', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)
# Add text annotations for diagonal (same pairs)
for i in range(min(num_samples, 10)): # Annotate first 10 for readability
if labels[i] == 1: # True match
plt.text(i, i, '', ha='center', va='center',
color='white', fontsize=12, fontweight='bold')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'Correlation heatmap saved to {save_path}')
plt.show()
# Print statistics
diagonal = np.diag(similarity_matrix)
off_diagonal = similarity_matrix[~np.eye(num_samples, dtype=bool)]
print(f'\n--- Correlation Statistics ---')
print(f'Diagonal (matched pairs):')
print(f' Mean: {diagonal.mean():.4f}')
print(f' Std: {diagonal.std():.4f}')
print(f'\nOff-diagonal (mismatched pairs):')
print(f' Mean: {off_diagonal.mean():.4f}')
print(f' Std: {off_diagonal.std():.4f}')
# =============================================================================
# MAIN TRAINING SCRIPT
# =============================================================================
def main():
"""Основная функция обучения"""
# Configuration
config_dict = config.copy()
if isinstance(config_dict.get('image_size'), list):
config_dict['image_size'] = tuple(config_dict['image_size'])
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'\n{"="*70}')
print(f'Siamese Network Training for Map Correlation')
print(f'Обучение сиамской сети для корреляции снимков')
print(f'{"="*70}')
print(f'Using device: {device}')
if torch.cuda.is_available():
print(f'GPU: {torch.cuda.get_device_name(0)}')
print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')
# Create data loaders
print(f'\n{"="*70}')
print('Creating data loaders...')
print(f'{"="*70}')
train_loader, val_loader = create_data_loaders(
root_dir=config_dict['data_dir'],
batch_size=config_dict['batch_size'],
train_split=config_dict['train_split'],
num_workers=config_dict['num_workers'],
image_size=config_dict['image_size'],
augment_train=True,
augment_val=False,
device=device
)
print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')
print(f'Train samples: {len(train_loader.dataset)}')
print(f'Val samples: {len(val_loader.dataset)}')
# Create model
print(f'\n{"="*70}')
print('Creating model...')
print(f'{"="*70}')
model = create_similarity_model(
model_type='backbone',
input_size=config_dict['image_size'][0] if isinstance(config_dict['image_size'], (tuple, list)) else config_dict['image_size'],
input_channels=3,
backbone_name='resnet18',
pretrained=True,
dropout_rate=0.3,
use_batch_norm=True
)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')
# Create trainer
trainer = SimilarityTrainer(
model=model,
trainloader=train_loader,
valloader=val_loader,
device=device,
config=config_dict
)
# Train model
print(f'\n{"="*70}')
print('Starting training...')
print(f'{"="*70}')
history = trainer.train(config_dict['epochs'])
# =============================================================================
# VISUALIZATION AND RESULTS
# =============================================================================
print(f'\n{"="*70}')
print('Generating visualizations...')
print(f'{"="*70}')
output_dir = config_dict.get('output_dir', 'runs/similarity')
vis_dir = os.path.join(output_dir, 'visualizations')
os.makedirs(vis_dir, exist_ok=True)
# 1. Training history
print('\n1. Plotting training history...')
plot_training_history(
history,
save_path=os.path.join(vis_dir, 'training_history.png')
)
# 2. Validation metrics
print('\n2. Computing validation predictions...')
trainer.model.eval()
val_predictions = []
val_targets = []
with torch.no_grad():
for batch in tqdm(val_loader, desc='Validation'):
google_img = batch['google_img'].to(device)
yandex_img = batch['yandex_img'].to(device)
target = batch['same_domain'].float().unsqueeze(1)
output = trainer.model(google_img, yandex_img)
val_predictions.append(output.cpu())
val_targets.append(target.cpu())
val_predictions = torch.cat(val_predictions, dim=0)
val_targets = torch.cat(val_targets, dim=0)
# 3. ROC curve
print('\n3. Plotting ROC curve...')
roc_auc = plot_roc_curve(
val_predictions,
val_targets,
save_path=os.path.join(vis_dir, 'roc_curve.png')
)
print(f'ROC AUC Score: {roc_auc:.4f}')
# 4. Confusion matrix
print('\n4. Plotting confusion matrix...')
plot_confusion_matrix(
val_predictions,
val_targets,
threshold=0.5,
save_path=os.path.join(vis_dir, 'confusion_matrix.png')
)
# 5. Similarity distribution
print('\n5. Plotting similarity distribution...')
plot_similarity_distribution(
val_predictions,
val_targets,
save_path=os.path.join(vis_dir, 'similarity_distribution.png')
)
# 6. Sample predictions
print('\n6. Visualizing sample predictions...')
visualize_sample_predictions(
trainer.model,
val_loader.dataset,
device,
num_samples=8,
save_path=os.path.join(vis_dir, 'sample_predictions.png')
)
# 7. Feature space visualization
print('\n7. Visualizing feature space (t-SNE)...')
visualize_feature_space(
trainer.model,
val_loader,
device,
max_samples=500,
save_path=os.path.join(vis_dir, 'feature_space_tsne.png')
)
# 8. Correlation heatmap
print('\n8. Generating correlation heatmap...')
generate_correlation_heatmap(
trainer.model,
val_loader,
device,
num_samples=20,
save_path=os.path.join(vis_dir, 'correlation_heatmap.png')
)
# =============================================================================
# FINAL RESULTS SUMMARY
# =============================================================================
print(f'\n{"="*70}')
print('FINAL RESULTS SUMMARY')
print('ИТОГОВЫЕ РЕЗУЛЬТАТЫ')
print(f'{"="*70}')
print(f'\nBest Validation Loss: {trainer.best_val_loss:.4f}')
print(f'Final Validation Accuracy: {history["val_accuracy"][-1]:.4f}')
print(f'Final Validation F1 Score: {history["val_f1"][-1]:.4f}')
print(f'Final Validation Precision: {history["val_precision"][-1]:.4f}')
print(f'Final Validation Recall: {history["val_recall"][-1]:.4f}')
print(f'ROC AUC Score: {roc_auc:.4f}')
print(f'\nCheckpoints saved to: {os.path.join(output_dir, "checkpoints")}')
print(f'Visualizations saved to: {vis_dir}')
print(f'\n{"="*70}')
print('Training and visualization completed successfully!')
print('Обучение и визуализация завершены успешно!')
print(f'{"="*70}\n')
if __name__ == '__main__':
main()