import os
import time
from datetime import datetime
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# NOTE(review): `SimilarityLoss`, `create_data_loaders`, `create_similarity_model`
# and the module-level `config` mapping are used below but are not defined in
# this file. They must be imported/defined before this code runs -- confirm
# which project modules provide them.


# =============================================================================
# TRAINING LOOP WITH VISUALIZATION
# =============================================================================

class SimilarityTrainer:
    """Train and validate a two-input similarity model with TensorBoard logging.

    The trainer feeds `batch['google_img']` and `batch['yandex_img']` to the
    model and compares its output with `batch['same_domain']`.
    """

    def __init__(
        self,
        model: nn.Module,
        trainloader: DataLoader,
        valloader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        """Move the model to `device` and build the loss and Adam optimizer.

        Args:
            model: Network called as `model(google_img, yandex_img)`.
            trainloader: Loader yielding dict batches for training.
            valloader: Loader yielding dict batches for validation.
            device: Device on which the model and batches are placed.
            config: Settings read with `.get`: `learning_rate`, `beta1`,
                `beta2`, `grad_clip`, `log_interval`, `output_dir`,
                `save_interval`, `early_stopping_patience`.
        """
        self.model = model.to(device)
        self.trainloader = trainloader
        self.valloader = valloader
        self.device = device
        self.config = config

        self.criterion = SimilarityLoss()
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get('learning_rate', 2e-4),
            betas=(config.get('beta1', 0.5), config.get('beta2', 0.999)),
        )

        # Created in train(); None means "no TensorBoard logging".
        self.writer = None
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0

        # Per-epoch metric history.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_accuracy': [],
            'val_precision': [],
            'val_recall': [],
            'val_f1': [],
            'learning_rate': [],
        }

    def train_epoch(self, epoch: int) -> dict:
        """Run one training epoch.

        Args:
            epoch: Epoch number, used for the progress bar label and the
                TensorBoard global step.

        Returns:
            Dict with the sample-weighted mean `loss` plus the mean of every
            metric returned by `criterion.compute_metrics` on the batches
            that were sampled at `log_interval`.
        """
        self.model.train()

        total_loss = 0
        total_samples = 0
        all_metrics = []

        # A zero/negative interval would raise ZeroDivisionError in the modulo
        # below, so clamp it to "every batch".
        log_interval = max(1, int(self.config.get('log_interval', 10)))
        grad_clip = self.config.get('grad_clip', None)

        pbar = tqdm(self.trainloader, desc=f'Epoch {epoch}')
        for batch_idx, batch in enumerate(pbar):
            google_img = batch['google_img'].to(self.device)
            yandex_img = batch['yandex_img'].to(self.device)
            target = batch['same_domain'].float().to(self.device).unsqueeze(1)

            self.optimizer.zero_grad()

            # Forward pass
            output = self.model(google_img, yandex_img)
            loss = self.criterion(output, target)

            # Backward pass
            loss.backward()

            # Gradient clipping (skipped when grad_clip is unset or 0)
            if grad_clip:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)

            self.optimizer.step()

            batch_size = google_img.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            # Metrics are only computed on every `log_interval`-th batch.
            if batch_idx % log_interval == 0:
                metrics = self.criterion.compute_metrics(output, target)
                all_metrics.append(metrics)

                pbar.set_postfix({
                    'loss': f"{loss.item():.4f}",
                    'acc': f"{metrics['accuracy']:.4f}"
                })

                if self.writer:
                    global_step = epoch * len(self.trainloader) + batch_idx
                    self.writer.add_scalar('train/loss', loss.item(), global_step)
                    self.writer.add_scalar('train/accuracy', metrics['accuracy'], global_step)

        avg_loss = total_loss / total_samples

        # Average the sampled metrics
        if all_metrics:
            avg_metrics = {
                key: sum(m[key] for m in all_metrics) / len(all_metrics)
                for key in all_metrics[0]
            }
        else:
            avg_metrics = {}

        return {'loss': avg_loss, **avg_metrics}

    def validate(self) -> dict:
        """Evaluate the model on the validation loader.

        Returns:
            Dict with the sample-weighted mean `loss`, the unweighted mean
            over batches of every metric from `criterion.compute_metrics`,
            and the concatenated CPU tensors `predictions` and `targets`
            (kept for ROC / confusion-matrix plots).
        """
        self.model.eval()

        total_loss = 0
        total_samples = 0
        all_metrics = []

        # Raw outputs and targets for ROC and confusion matrix
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for batch in tqdm(self.valloader, desc='Validation'):
                google_img = batch['google_img'].to(self.device)
                yandex_img = batch['yandex_img'].to(self.device)
                target = batch['same_domain'].float().to(self.device).unsqueeze(1)

                output = self.model(google_img, yandex_img)
                loss = self.criterion(output, target)

                batch_size = google_img.size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size

                all_metrics.append(self.criterion.compute_metrics(output, target))
                all_predictions.append(output.cpu())
                all_targets.append(target.cpu())

        avg_loss = total_loss / total_samples
        avg_metrics = {
            key: sum(m[key] for m in all_metrics) / len(all_metrics)
            for key in all_metrics[0]
        }

        return {
            'loss': avg_loss,
            **avg_metrics,
            'predictions': torch.cat(all_predictions, dim=0),
            'targets': torch.cat(all_targets, dim=0),
        }

    def train(self, num_epochs: int):
        """Run the full training loop.

        Each epoch trains, validates, records history, logs to TensorBoard,
        saves checkpoints and checks early stopping.

        Args:
            num_epochs: Maximum number of epochs to run.

        Returns:
            The `history` dict of per-epoch metrics.
        """
        log_dir = self.config.get('output_dir', 'runs/similarity')
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)

        print(f'\n{"="*70}')
        print(f'Starting training for {num_epochs} epochs')
        print(f'Logging to {log_dir}')
        print(f'{"="*70}\n')

        # Clamp so that a zero interval cannot raise ZeroDivisionError.
        save_interval = max(1, int(self.config.get('save_interval', 5)))
        patience = self.config.get('early_stopping_patience', 20)

        start_time = time.time()

        try:
            for epoch in range(1, num_epochs + 1):
                epoch_start = time.time()

                print(f'\n--- Epoch {epoch}/{num_epochs} ---')

                # Train
                train_metrics = self.train_epoch(epoch)

                # Validate
                val_metrics = self.validate()

                # Store history
                self.history['train_loss'].append(train_metrics['loss'])
                self.history['val_loss'].append(val_metrics['loss'])
                self.history['val_accuracy'].append(val_metrics['accuracy'])
                self.history['val_precision'].append(val_metrics['precision'])
                self.history['val_recall'].append(val_metrics['recall'])
                self.history['val_f1'].append(val_metrics['f1'])
                self.history['learning_rate'].append(
                    self.optimizer.param_groups[0]['lr']
                )

                # Print metrics
                print(f'\nTrain Loss: {train_metrics["loss"]:.4f}')
                print(f'Val Loss: {val_metrics["loss"]:.4f}')
                print(f'Val Accuracy: {val_metrics["accuracy"]:.4f}')
                print(f'Val Precision: {val_metrics["precision"]:.4f}')
                print(f'Val Recall: {val_metrics["recall"]:.4f}')
                print(f'Val F1: {val_metrics["f1"]:.4f}')

                epoch_time = time.time() - epoch_start
                print(f'Epoch time: {epoch_time:.2f}s')

                # TensorBoard logging
                if self.writer:
                    self.writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
                    self.writer.add_scalar('epoch/val_loss', val_metrics['loss'], epoch)
                    self.writer.add_scalar('epoch/val_accuracy', val_metrics['accuracy'], epoch)
                    self.writer.add_scalar('epoch/val_precision', val_metrics['precision'], epoch)
                    self.writer.add_scalar('epoch/val_recall', val_metrics['recall'], epoch)
                    self.writer.add_scalar('epoch/val_f1', val_metrics['f1'], epoch)

                # Save checkpoint
                improved = val_metrics['loss'] < self.best_val_loss
                if improved:
                    self.best_val_loss = val_metrics['loss']
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_metrics['loss'], is_best=True)
                    print(f'✓ New best model saved with val loss: {val_metrics["loss"]:.4f}')
                else:
                    self.epochs_without_improvement += 1

                # save_checkpoint(is_best=True) already wrote this epoch's
                # file, so the periodic save is only needed otherwise.
                if not improved and epoch % save_interval == 0:
                    self.save_checkpoint(epoch, val_metrics['loss'], is_best=False)

                # Early stopping
                if self.epochs_without_improvement >= patience:
                    print(f'\n⚠ Early stopping triggered after {patience} epochs without improvement')
                    break
        finally:
            # Flush and release the event file even if an epoch raised.
            self.writer.close()

        total_time = time.time() - start_time
        print(f'\n{"="*70}')
        print(f'Training completed in {total_time/60:.2f} minutes')
        print(f'Best validation loss: {self.best_val_loss:.4f}')
        print(f'{"="*70}\n')

        return self.history

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write a checkpoint for `epoch` under `<output_dir>/checkpoints`.

        Args:
            epoch: Epoch number; used in the file name.
            val_loss: Validation loss stored alongside the weights.
            is_best: When true, the same payload is also written to
                `best_model.pt`.
        """
        checkpoint_dir = os.path.join(
            self.config.get('output_dir', 'runs/similarity'),
            'checkpoints'
        )
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
            'config': self.config,
            'history': self.history
        }

        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch}.pt')
        torch.save(checkpoint, checkpoint_path)

        if is_best:
            best_path = os.path.join(checkpoint_dir, 'best_model.pt')
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model, optimizer and (if present) history from a checkpoint.

        Args:
            checkpoint_path: Path to a file written by `save_checkpoint`.

        Returns:
            Tuple `(epoch, val_loss)` stored in the checkpoint.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if 'history' in checkpoint:
            self.history = checkpoint['history']

        return checkpoint['epoch'], checkpoint['val_loss']


# =============================================================================
# VISUALIZATION FUNCTIONS
# =============================================================================

def plot_training_history(history: dict, save_path: Optional[str] = None):
    """Plot loss, validation metrics and learning rate per epoch in a 2x3 grid.

    Args:
        history: Dict with the lists `train_loss`, `val_loss`, `val_accuracy`,
            `val_f1`, `val_precision`, `val_recall` and `learning_rate`.
        save_path: If given, the figure is also saved to this path.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('Training History - Siamese Network для корреляции снимков',
                 fontsize=16, fontweight='bold')

    epochs = range(1, len(history['train_loss']) + 1)

    # Loss
    loss_ax = axes[0, 0]
    loss_ax.plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    loss_ax.plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Loss Curves')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Metrics that are drawn on a fixed [0, 1] axis:
    # (axis, history key, line style, y label, title)
    bounded_panels = [
        (axes[0, 1], 'val_accuracy', 'g-', 'Accuracy', 'Validation Accuracy'),
        (axes[0, 2], 'val_f1', 'm-', 'F1 Score', 'Validation F1 Score'),
        (axes[1, 0], 'val_precision', 'c-', 'Precision', 'Validation Precision'),
        (axes[1, 1], 'val_recall', 'y-', 'Recall', 'Validation Recall'),
    ]
    for ax, key, line_style, y_label, title in bounded_panels:
        ax.plot(epochs, history[key], line_style, linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 1])

    # Learning Rate
    lr_ax = axes[1, 2]
    lr_ax.plot(epochs, history['learning_rate'], 'k-', linewidth=2)
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_title('Learning Rate Schedule')
    lr_ax.grid(True, alpha=0.3)
    lr_ax.set_yscale('log')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Training history plot saved to {save_path}')

    plt.show()


def plot_roc_curve(predictions: torch.Tensor, targets: torch.Tensor,
                   save_path: Optional[str] = None):
    """Plot the ROC curve and return its area under the curve.

    Args:
        predictions: Model scores; flattened before use.
        targets: Ground-truth labels; flattened before use.
        save_path: If given, the figure is also saved to this path.

    Returns:
        The AUC computed from the ROC curve.
    """
    from sklearn.metrics import roc_curve, auc

    # detach()/cpu() make this safe for tensors that still carry grad or
    # live on an accelerator; both are no-ops for plain CPU tensors.
    predictions_np = predictions.detach().cpu().numpy().flatten()
    targets_np = targets.detach().cpu().numpy().flatten()

    fpr, tpr, thresholds = roc_curve(targets_np, predictions_np)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(10, 8))
    plt.plot(fpr, tpr, color='darkorange', lw=2,
             label=f'ROC curve (AUC = {roc_auc:.3f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate', fontsize=12)
    plt.title('ROC Curve - Siamese Network', fontsize=14, fontweight='bold')
    plt.legend(loc="lower right", fontsize=12)
    plt.grid(True, alpha=0.3)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'ROC curve saved to {save_path}')

    plt.show()

    return roc_auc


def plot_confusion_matrix(predictions: torch.Tensor, targets: torch.Tensor,
                          threshold: float = 0.5, save_path: Optional[str] = None):
    """Plot the 2x2 confusion matrix of thresholded predictions.

    Args:
        predictions: Model scores; a score `>= threshold` counts as class 1.
        targets: Ground-truth labels; a value `>= 0.5` counts as class 1.
        threshold: Decision threshold applied to `predictions`.
        save_path: If given, the figure is also saved to this path.
    """
    from sklearn.metrics import confusion_matrix

    predictions_binary = (predictions.detach().cpu().numpy().flatten() >= threshold).astype(int)
    targets_binary = (targets.detach().cpu().numpy().flatten() >= 0.5).astype(int)

    # Pin both classes so the matrix stays 2x2 (and matches the two tick
    # labels) even when only one class occurs in the data.
    cm = confusion_matrix(targets_binary, predictions_binary, labels=[0, 1])

    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix - Correlation Detection', fontsize=14, fontweight='bold')
    plt.colorbar()

    classes = ['Different Domains', 'Same Domain']
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, fontsize=12)
    plt.yticks(tick_marks, classes, fontsize=12)

    # Add text annotations; switch text colour on dark cells for contrast
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black",
                     fontsize=16, fontweight='bold')

    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Confusion matrix saved to {save_path}')

    plt.show()


def _print_score_stats(title: str, scores: np.ndarray) -> None:
    """Print mean/std/min/max of `scores`, or a placeholder when it is empty."""
    print(title)
    if scores.size == 0:
        # min()/max() raise ValueError on empty arrays.
        print(' (no samples)')
        return
    print(f' Mean: {scores.mean():.4f}')
    print(f' Std: {scores.std():.4f}')
    print(f' Min: {scores.min():.4f}')
    print(f' Max: {scores.max():.4f}')


def plot_similarity_distribution(predictions: torch.Tensor, targets: torch.Tensor,
                                 save_path: Optional[str] = None):
    """Plot histograms and box plots of predicted scores split by true label.

    Args:
        predictions: Model scores; flattened before use.
        targets: Ground-truth labels; values `>= 0.5` form the "same" group.
        save_path: If given, the figure is also saved to this path.
    """
    predictions_np = predictions.detach().cpu().numpy().flatten()
    targets_np = targets.detach().cpu().numpy().flatten()

    same_domain = predictions_np[targets_np >= 0.5]
    diff_domain = predictions_np[targets_np < 0.5]

    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.hist(same_domain, bins=50, alpha=0.7, color='green',
             edgecolor='black', label='Same Domain')
    plt.hist(diff_domain, bins=50, alpha=0.7, color='red',
             edgecolor='black', label='Different Domains')
    plt.xlabel('Predicted Similarity Score', fontsize=12)
    plt.ylabel('Frequency', fontsize=12)
    plt.title('Distribution of Similarity Scores', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    plt.boxplot([diff_domain, same_domain], labels=['Different', 'Same'])
    plt.ylabel('Similarity Score', fontsize=12)
    plt.xlabel('Domain Match', fontsize=12)
    plt.title('Similarity Score by Domain Match', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Similarity distribution plot saved to {save_path}')

    plt.show()

    # Print statistics
    print(f'\n--- Similarity Score Statistics ---')
    _print_score_stats(f'Same Domain:', same_domain)
    _print_score_stats(f'\nDifferent Domains:', diff_domain)


def visualize_sample_predictions(model: nn.Module, dataset, device: torch.device,
                                 num_samples: int = 8, save_path: Optional[str] = None):
    """Show randomly chosen image pairs next to the model's prediction.

    Args:
        model: Network called as `model(google_img, yandex_img)`.
        dataset: Indexable dataset whose items are dicts with `google_img`,
            `yandex_img` and `same_domain`.
        device: Device on which inference is run.
        num_samples: Number of pairs to draw; capped at `len(dataset)`.
        save_path: If given, the figure is also saved to this path.

    Raises:
        ValueError: If there is nothing to sample.
    """
    model.eval()

    # Sampling without replacement cannot return more items than exist.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        raise ValueError('visualize_sample_predictions needs at least one sample')

    # Get random samples
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    # squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5*num_samples), squeeze=False)

    fig.suptitle('Sample Predictions - Siamese Network для корреляции карт',
                 fontsize=16, fontweight='bold')

    # Denormalization constants (assuming ImageNet normalization)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    with torch.no_grad():
        for idx, sample_idx in enumerate(indices):
            sample = dataset[sample_idx]

            google_img = sample['google_img'].unsqueeze(0).to(device)
            yandex_img = sample['yandex_img'].unsqueeze(0).to(device)
            # Same 0.5 cut-off as plot_confusion_matrix uses for targets.
            true_label = int(float(sample['same_domain']) >= 0.5)

            # Predict
            pred_similarity = model(google_img, yandex_img).item()
            pred_label = int(pred_similarity >= 0.5)

            # Channels-last, denormalized and clipped for imshow
            google_np = google_img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
            yandex_np = yandex_img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
            google_np = np.clip(std * google_np + mean, 0, 1)
            yandex_np = np.clip(std * yandex_np + mean, 0, 1)

            # Plot Google image
            axes[idx, 0].imshow(google_np)
            axes[idx, 0].set_title('Google Map', fontsize=12, fontweight='bold')
            axes[idx, 0].axis('off')

            # Plot Yandex image
            axes[idx, 1].imshow(yandex_np)
            axes[idx, 1].set_title('Yandex Map', fontsize=12, fontweight='bold')
            axes[idx, 1].axis('off')

            # Plot prediction info
            axes[idx, 2].axis('off')

            # Determine color based on correctness
            is_correct = (pred_label == true_label)
            color = 'green' if is_correct else 'red'
            result = '✓ Correct' if is_correct else '✗ Incorrect'

            info_text = '\n'.join([
                f'Prediction: {pred_similarity:.4f}',
                '',
                f"Predicted Label: {'Same' if pred_label == 1 else 'Different'}",
                f"True Label: {'Same' if true_label == 1 else 'Different'}",
                '',
                f'{result}',
            ])

            axes[idx, 2].text(0.5, 0.5, info_text,
                              ha='center', va='center', fontsize=12,
                              bbox=dict(boxstyle='round', facecolor=color, alpha=0.2),
                              transform=axes[idx, 2].transAxes)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Sample predictions saved to {save_path}')

    plt.show()


def visualize_feature_space(model: nn.Module, dataloader: DataLoader, device: torch.device,
                            max_samples: int = 500, save_path: Optional[str] = None):
    """Project extracted features of both image sources to 2-D with t-SNE.

    Args:
        model: Network exposing `extract_features(images)`.
        dataloader: Loader yielding dict batches with `google_img`,
            `yandex_img` and `same_domain`.
        device: Device on which feature extraction is run.
        max_samples: Stop reading batches once at least this many pairs were
            collected (the last batch is not truncated).
        save_path: If given, the figure is also saved to this path.
    """
    from sklearn.manifold import TSNE

    model.eval()

    all_features_google = []
    all_features_yandex = []
    all_labels = []
    collected = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Extracting features'):
            # Count real samples instead of i * batch_size, which breaks when
            # the loader has batch_size=None.
            if collected >= max_samples:
                break

            google_img = batch['google_img'].to(device)
            yandex_img = batch['yandex_img'].to(device)
            labels = batch['same_domain'].cpu().numpy()

            # Extract features
            all_features_google.append(model.extract_features(google_img).cpu().numpy())
            all_features_yandex.append(model.extract_features(yandex_img).cpu().numpy())
            all_labels.append(labels)
            collected += google_img.size(0)

    all_features_google = np.concatenate(all_features_google, axis=0)
    all_features_yandex = np.concatenate(all_features_yandex, axis=0)
    labels = np.concatenate(all_labels, axis=0).reshape(-1)

    # Embed both sources together so they share one 2-D space
    all_features = np.concatenate([all_features_google, all_features_yandex], axis=0)
    n_samples = all_features_google.shape[0]

    print(f'\nApplying t-SNE to {all_features.shape[0]} samples...')
    # scikit-learn requires perplexity < number of points.
    perplexity = min(30, all_features.shape[0] - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    features_2d = tsne.fit_transform(all_features)

    # Split back into Google and Yandex
    features_google_2d = features_2d[:n_samples]
    features_yandex_2d = features_2d[n_samples:]

    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    panels = [
        (axes[0], features_google_2d, 'Google Maps Features (t-SNE)'),
        (axes[1], features_yandex_2d, 'Yandex Maps Features (t-SNE)'),
    ]
    for ax, points, title in panels:
        for label in [0, 1]:
            mask = labels == label
            ax.scatter(
                points[mask, 0],
                points[mask, 1],
                c='green' if label == 1 else 'red',
                label='Same Domain' if label == 1 else 'Different Domains',
                alpha=0.6,
                s=50
            )
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('t-SNE Component 1', fontsize=12)
        ax.set_ylabel('t-SNE Component 2', fontsize=12)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Feature space visualization saved to {save_path}')

    plt.show()


def generate_correlation_heatmap(model: nn.Module, dataloader: DataLoader, device: torch.device,
                                 num_samples: int = 20, save_path: Optional[str] = None):
    """Plot the model's score for every Google/Yandex image combination.

    The first pair of each batch is taken until `num_samples` pairs are
    collected or the loader is exhausted.

    Args:
        model: Network called as `model(google_img, yandex_img)`.
        dataloader: Loader yielding dict batches with `google_img`,
            `yandex_img` and `same_domain`.
        device: Device on which inference is run.
        num_samples: Maximum number of pairs (matrix side length).
        save_path: If given, the figure is also saved to this path.

    Raises:
        ValueError: If the loader yields no batches.
    """
    model.eval()

    # Collect samples
    google_images = []
    yandex_images = []
    labels = []

    for batch in dataloader:
        if len(labels) >= num_samples:
            break

        # Slice before the transfer so only one image per batch is moved.
        google_images.append(batch['google_img'][:1].to(device))
        yandex_images.append(batch['yandex_img'][:1].to(device))
        labels.append(batch['same_domain'][:1].item())

    # The loader may have fewer batches than requested; size everything by
    # what was actually collected.
    n = len(labels)
    if n == 0:
        raise ValueError('generate_correlation_heatmap received an empty dataloader')

    google_images = torch.cat(google_images, dim=0)
    yandex_images = torch.cat(yandex_images, dim=0)

    # Compute similarity matrix
    similarity_matrix = np.zeros((n, n))

    with torch.no_grad():
        for i in tqdm(range(n), desc='Computing correlations'):
            google_i = google_images[i:i+1]
            for j in range(n):
                yandex_j = yandex_images[j:j+1]
                similarity_matrix[i, j] = model(google_i, yandex_j).item()

    # Plot heatmap
    plt.figure(figsize=(14, 12))

    im = plt.imshow(similarity_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
    plt.colorbar(im, label='Similarity Score', fraction=0.046, pad=0.04)

    plt.title('Correlation Heatmap: Google vs Yandex Maps\n(Матрица корреляций снимков)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Yandex Map Index', fontsize=12)
    plt.ylabel('Google Map Index', fontsize=12)

    # Add grid
    plt.grid(True, which='both', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    # Mark diagonal cells whose pair is labelled as a true match
    for i in range(min(n, 10)):  # Annotate first 10 for readability
        if labels[i] == 1:  # True match
            plt.text(i, i, '✓', ha='center', va='center',
                     color='white', fontsize=12, fontweight='bold')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Correlation heatmap saved to {save_path}')

    plt.show()

    # Print statistics
    diagonal = np.diag(similarity_matrix)
    off_diagonal = similarity_matrix[~np.eye(n, dtype=bool)]

    print(f'\n--- Correlation Statistics ---')
    print(f'Diagonal (matched pairs):')
    print(f' Mean: {diagonal.mean():.4f}')
    print(f' Std: {diagonal.std():.4f}')
    print(f'\nOff-diagonal (mismatched pairs):')
    if off_diagonal.size:
        print(f' Mean: {off_diagonal.mean():.4f}')
        print(f' Std: {off_diagonal.std():.4f}')
    else:
        # A 1x1 matrix has no off-diagonal cells.
        print(' (no samples)')


# =============================================================================
# MAIN TRAINING SCRIPT
# =============================================================================

def main():
    """Train the similarity model and write all visualizations to disk."""
    # Configuration
    # NOTE(review): `config` is not defined in this file -- confirm where the
    # module-level configuration mapping comes from.
    config_dict = config.copy()
    if isinstance(config_dict.get('image_size'), list):
        config_dict['image_size'] = tuple(config_dict['image_size'])

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'\n{"="*70}')
    print(f'Siamese Network Training for Map Correlation')
    print(f'Обучение сиамской сети для корреляции снимков')
    print(f'{"="*70}')
    print(f'Using device: {device}')
    if torch.cuda.is_available():
        print(f'GPU: {torch.cuda.get_device_name(0)}')
        print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')

    # Create data loaders
    print(f'\n{"="*70}')
    print('Creating data loaders...')
    print(f'{"="*70}')

    train_loader, val_loader = create_data_loaders(
        root_dir=config_dict['data_dir'],
        batch_size=config_dict['batch_size'],
        train_split=config_dict['train_split'],
        num_workers=config_dict['num_workers'],
        image_size=config_dict['image_size'],
        augment_train=True,
        augment_val=False,
        device=device
    )

    print(f'Train batches: {len(train_loader)}')
    print(f'Val batches: {len(val_loader)}')
    print(f'Train samples: {len(train_loader.dataset)}')
    print(f'Val samples: {len(val_loader.dataset)}')

    # Create model
    print(f'\n{"="*70}')
    print('Creating model...')
    print(f'{"="*70}')

    image_size = config_dict['image_size']
    input_size = image_size[0] if isinstance(image_size, (tuple, list)) else image_size

    model = create_similarity_model(
        model_type='backbone',
        input_size=input_size,
        input_channels=3,
        backbone_name='resnet18',
        pretrained=True,
        dropout_rate=0.3,
        use_batch_norm=True
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total parameters: {total_params:,}')
    print(f'Trainable parameters: {trainable_params:,}')

    # Create trainer
    trainer = SimilarityTrainer(
        model=model,
        trainloader=train_loader,
        valloader=val_loader,
        device=device,
        config=config_dict
    )

    # Train model
    print(f'\n{"="*70}')
    print('Starting training...')
    print(f'{"="*70}')

    history = trainer.train(config_dict['epochs'])

    # =========================================================================
    # VISUALIZATION AND RESULTS
    # =========================================================================

    print(f'\n{"="*70}')
    print('Generating visualizations...')
    print(f'{"="*70}')

    output_dir = config_dict.get('output_dir', 'runs/similarity')
    vis_dir = os.path.join(output_dir, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)

    # 1. Training history
    print('\n1. Plotting training history...')
    plot_training_history(
        history,
        save_path=os.path.join(vis_dir, 'training_history.png')
    )

    # 2. Validation metrics
    print('\n2. Computing validation predictions...')
    # validate() already gathers the CPU predictions/targets over val_loader,
    # so reuse it instead of repeating the loop here.
    val_results = trainer.validate()
    val_predictions = val_results['predictions']
    val_targets = val_results['targets']

    # 3. ROC curve
    print('\n3. Plotting ROC curve...')
    roc_auc = plot_roc_curve(
        val_predictions,
        val_targets,
        save_path=os.path.join(vis_dir, 'roc_curve.png')
    )
    print(f'ROC AUC Score: {roc_auc:.4f}')

    # 4. Confusion matrix
    print('\n4. Plotting confusion matrix...')
    plot_confusion_matrix(
        val_predictions,
        val_targets,
        threshold=0.5,
        save_path=os.path.join(vis_dir, 'confusion_matrix.png')
    )

    # 5. Similarity distribution
    print('\n5. Plotting similarity distribution...')
    plot_similarity_distribution(
        val_predictions,
        val_targets,
        save_path=os.path.join(vis_dir, 'similarity_distribution.png')
    )

    # 6. Sample predictions
    print('\n6. Visualizing sample predictions...')
    visualize_sample_predictions(
        trainer.model,
        val_loader.dataset,
        device,
        num_samples=8,
        save_path=os.path.join(vis_dir, 'sample_predictions.png')
    )

    # 7. Feature space visualization
    print('\n7. Visualizing feature space (t-SNE)...')
    visualize_feature_space(
        trainer.model,
        val_loader,
        device,
        max_samples=500,
        save_path=os.path.join(vis_dir, 'feature_space_tsne.png')
    )

    # 8. Correlation heatmap
    print('\n8. Generating correlation heatmap...')
    generate_correlation_heatmap(
        trainer.model,
        val_loader,
        device,
        num_samples=20,
        save_path=os.path.join(vis_dir, 'correlation_heatmap.png')
    )

    # =========================================================================
    # FINAL RESULTS SUMMARY
    # =========================================================================

    print(f'\n{"="*70}')
    print('FINAL RESULTS SUMMARY')
    print('ИТОГОВЫЕ РЕЗУЛЬТАТЫ')
    print(f'{"="*70}')
    print(f'\nBest Validation Loss: {trainer.best_val_loss:.4f}')
    print(f'Final Validation Accuracy: {history["val_accuracy"][-1]:.4f}')
    print(f'Final Validation F1 Score: {history["val_f1"][-1]:.4f}')
    print(f'Final Validation Precision: {history["val_precision"][-1]:.4f}')
    print(f'Final Validation Recall: {history["val_recall"][-1]:.4f}')
    print(f'ROC AUC Score: {roc_auc:.4f}')
    print(f'\nCheckpoints saved to: {os.path.join(output_dir, "checkpoints")}')
    print(f'Visualizations saved to: {vis_dir}')
    print(f'\n{"="*70}')
    print('Training and visualization completed successfully!')
    print('Обучение и визуализация завершены успешно!')
    print(f'{"="*70}\n')


if __name__ == '__main__':
    main()