Files
autopilot/models/SiaN-similarity/train-adv.py

917 lines
33 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# =============================================================================
# TRAINING LOOP WITH VISUALIZATION
# =============================================================================
class SimilarityTrainer:
    """Train, validate and checkpoint a two-input similarity model.

    Each batch is expected to be a mapping with the keys ``'google_img'``,
    ``'yandex_img'`` and ``'same_domain'``. The model is called as
    ``model(google_img, yandex_img)`` and optimized with Adam against
    ``SimilarityLoss``. Scalars are logged to TensorBoard while a writer is
    attached, and per-epoch metrics are accumulated in ``self.history``.

    Config keys read by this class (all optional): ``learning_rate``,
    ``beta1``, ``beta2``, ``grad_clip``, ``log_interval``, ``output_dir``,
    ``save_interval`` and ``early_stopping_patience``.
    """

    def __init__(
        self,
        model: nn.Module,
        trainloader: DataLoader,
        valloader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        """Move the model to ``device`` and set up loss, optimizer and history.

        Args:
            model: Network called with the two image batches.
            trainloader: Loader iterated by :meth:`train_epoch`.
            valloader: Loader iterated by :meth:`validate`.
            device: Device the model and every batch are moved to.
            config: Hyper-parameter dictionary (see the class docstring).
        """
        self.model = model.to(device)
        self.trainloader = trainloader
        self.valloader = valloader
        self.device = device
        self.config = config
        self.criterion = SimilarityLoss()
        self.optimizer = optim.Adam(
            model.parameters(),
            lr=config.get('learning_rate', 2e-4),
            betas=(config.get('beta1', 0.5), config.get('beta2', 0.999)),
        )
        # Created lazily by train(); batch-level logging is skipped while None.
        self.writer = None
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0
        # Per-epoch metric history, one entry appended per completed epoch.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_accuracy': [],
            'val_precision': [],
            'val_recall': [],
            'val_f1': [],
            'learning_rate': [],
        }

    def _unpack_batch(self, batch: dict):
        """Move one batch to ``self.device`` and return (google, yandex, target)."""
        google_img = batch['google_img'].to(self.device)
        yandex_img = batch['yandex_img'].to(self.device)
        # The label is cast to float and gets an extra trailing dimension.
        target = batch['same_domain'].float().to(self.device).unsqueeze(1)
        return google_img, yandex_img, target

    @staticmethod
    def _average_metrics(all_metrics: list) -> dict:
        """Return the plain (unweighted) mean of each key over the metric dicts.

        NOTE(review): this is a mean of per-batch values, so a smaller last
        batch weighs as much as a full one.
        """
        if not all_metrics:
            return {}
        count = len(all_metrics)
        return {
            key: sum(m[key] for m in all_metrics) / count
            for key in all_metrics[0]
        }

    def train_epoch(self, epoch: int) -> dict:
        """Run one optimization pass over ``self.trainloader``.

        Args:
            epoch: Epoch number, used for the progress-bar label and for the
                TensorBoard global step.

        Returns:
            ``{'loss': <sample-weighted mean loss>, ...}`` plus the averaged
            metrics of the batches sampled every ``log_interval`` steps.

        Raises:
            ValueError: If the training loader yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        total_samples = 0
        all_metrics = []
        log_interval = self.config.get('log_interval', 10)
        grad_clip = self.config.get('grad_clip', None)
        pbar = tqdm(self.trainloader, desc=f'Epoch {epoch}')
        for batch_idx, batch in enumerate(pbar):
            google_img, yandex_img, target = self._unpack_batch(batch)
            self.optimizer.zero_grad()
            # Forward pass
            output = self.model(google_img, yandex_img)
            loss = self.criterion(output, target)
            # Backward pass
            loss.backward()
            # Optional gradient-norm clipping (disabled when unset or falsy).
            if grad_clip:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
            self.optimizer.step()

            loss_value = loss.item()
            batch_size = google_img.size(0)
            # Weight by batch size so the epoch mean is per-sample.
            total_loss += loss_value * batch_size
            total_samples += batch_size

            # Metrics are only computed on every ``log_interval``-th batch.
            if batch_idx % log_interval == 0:
                metrics = self.criterion.compute_metrics(output, target)
                all_metrics.append(metrics)
                pbar.set_postfix({
                    'loss': f"{loss_value:.4f}",
                    'acc': f"{metrics['accuracy']:.4f}",
                })
                if self.writer:
                    # NOTE(review): train() passes 1-based epochs, so the
                    # first logged step is len(trainloader), not 0.
                    global_step = epoch * len(self.trainloader) + batch_idx
                    self.writer.add_scalar('train/loss', loss_value, global_step)
                    self.writer.add_scalar('train/accuracy', metrics['accuracy'], global_step)

        if total_samples == 0:
            raise ValueError('Training loader produced no samples')
        avg_loss = total_loss / total_samples
        return {'loss': avg_loss, **self._average_metrics(all_metrics)}

    def validate(self) -> dict:
        """Evaluate the model on ``self.valloader`` without gradients.

        Returns:
            A dict with the sample-weighted mean ``'loss'``, the averaged
            metrics from ``criterion.compute_metrics``, and the concatenated
            CPU tensors ``'predictions'`` and ``'targets'``.

        Raises:
            ValueError: If the validation loader yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        total_samples = 0
        all_metrics = []
        # Raw outputs and labels are kept for ROC / confusion-matrix plots.
        all_predictions = []
        all_targets = []
        with torch.no_grad():
            for batch in tqdm(self.valloader, desc='Validation'):
                google_img, yandex_img, target = self._unpack_batch(batch)
                output = self.model(google_img, yandex_img)
                loss = self.criterion(output, target)
                batch_size = google_img.size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size
                all_metrics.append(self.criterion.compute_metrics(output, target))
                all_predictions.append(output.cpu())
                all_targets.append(target.cpu())

        if total_samples == 0:
            raise ValueError('Validation loader produced no samples')
        avg_loss = total_loss / total_samples
        return {
            'loss': avg_loss,
            **self._average_metrics(all_metrics),
            'predictions': torch.cat(all_predictions, dim=0),
            'targets': torch.cat(all_targets, dim=0),
        }

    def train(self, num_epochs: int):
        """Run the main train/validate loop with checkpointing and early stopping.

        Args:
            num_epochs: Maximum number of epochs to run (1-based counting).

        Returns:
            ``self.history`` with one entry per completed epoch.
        """
        log_dir = self.config.get('output_dir', 'runs/similarity')
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)
        banner = '=' * 70
        print(f'\n{banner}')
        print(f'Starting training for {num_epochs} epochs')
        print(f'Logging to {log_dir}')
        print(f'{banner}\n')

        save_interval = self.config.get('save_interval', 5)
        patience = self.config.get('early_stopping_patience', 20)
        start_time = time.time()
        try:
            for epoch in range(1, num_epochs + 1):
                epoch_start = time.time()
                print(f'\n--- Epoch {epoch}/{num_epochs} ---')
                train_metrics = self.train_epoch(epoch)
                val_metrics = self.validate()

                # Store history
                self.history['train_loss'].append(train_metrics['loss'])
                self.history['val_loss'].append(val_metrics['loss'])
                for name in ('accuracy', 'precision', 'recall', 'f1'):
                    self.history[f'val_{name}'].append(val_metrics[name])
                self.history['learning_rate'].append(
                    self.optimizer.param_groups[0]['lr']
                )

                # Print metrics
                print(f'\nTrain Loss: {train_metrics["loss"]:.4f}')
                print(f'Val Loss: {val_metrics["loss"]:.4f}')
                print(f'Val Accuracy: {val_metrics["accuracy"]:.4f}')
                print(f'Val Precision: {val_metrics["precision"]:.4f}')
                print(f'Val Recall: {val_metrics["recall"]:.4f}')
                print(f'Val F1: {val_metrics["f1"]:.4f}')
                epoch_time = time.time() - epoch_start
                print(f'Epoch time: {epoch_time:.2f}s')

                # TensorBoard logging
                if self.writer:
                    self.writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
                    self.writer.add_scalar('epoch/val_loss', val_metrics['loss'], epoch)
                    for name in ('accuracy', 'precision', 'recall', 'f1'):
                        self.writer.add_scalar(f'epoch/val_{name}', val_metrics[name], epoch)

                # Save checkpoint
                if val_metrics['loss'] < self.best_val_loss:
                    self.best_val_loss = val_metrics['loss']
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_metrics['loss'], is_best=True)
                    print(f'✓ New best model saved with val loss: {val_metrics["loss"]:.4f}')
                else:
                    self.epochs_without_improvement += 1
                    if epoch % save_interval == 0:
                        self.save_checkpoint(epoch, val_metrics['loss'], is_best=False)

                # Early stopping
                if self.epochs_without_improvement >= patience:
                    print(f'\n⚠ Early stopping triggered after {patience} epochs without improvement')
                    break
        finally:
            # Flush and close the event file even when training is interrupted
            # or raises; otherwise buffered scalars can be lost.
            self.writer.close()

        total_time = time.time() - start_time
        print(f'\n{banner}')
        print(f'Training completed in {total_time/60:.2f} minutes')
        print(f'Best validation loss: {self.best_val_loss:.4f}')
        print(f'{banner}\n')
        return self.history

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write model/optimizer state, config and history to disk.

        The file is saved as ``<output_dir>/checkpoints/checkpoint_epoch<N>.pt``
        and, when ``is_best`` is true, additionally as ``best_model.pt`` in
        the same directory.

        Args:
            epoch: Epoch number stored in the checkpoint and in its file name.
            val_loss: Validation loss stored in the checkpoint.
            is_best: Whether to also write ``best_model.pt``.
        """
        checkpoint_dir = os.path.join(
            self.config.get('output_dir', 'runs/similarity'),
            'checkpoints',
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
            'config': self.config,
            'history': self.history,
        }
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch}.pt')
        torch.save(checkpoint, checkpoint_path)
        if is_best:
            best_path = os.path.join(checkpoint_dir, 'best_model.pt')
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model, optimizer, history and best loss from a checkpoint.

        Args:
            checkpoint_path: Path to a file written by :meth:`save_checkpoint`.

        Returns:
            Tuple ``(epoch, val_loss)`` stored in the checkpoint.
        """
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources (consider weights_only=True where supported).
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'history' in checkpoint:
            self.history = checkpoint['history']
        # Restore the best loss seen so far; otherwise a resumed run starts
        # from +inf and overwrites best_model.pt on its first epoch.
        recorded_losses = checkpoint.get('history', {}).get('val_loss', [])
        self.best_val_loss = min(recorded_losses, default=checkpoint['val_loss'])
        return checkpoint['epoch'], checkpoint['val_loss']
# =============================================================================
# VISUALIZATION FUNCTIONS
# =============================================================================
def plot_training_history(history: dict, save_path: str = None):
    """Draw a 2x3 grid of per-epoch training curves.

    Panels: train/val loss, validation accuracy, F1, precision, recall and
    the learning rate (log scale).

    Args:
        history: Mapping with the lists ``train_loss``, ``val_loss``,
            ``val_accuracy``, ``val_f1``, ``val_precision``, ``val_recall``
            and ``learning_rate``.
        save_path: When truthy, the figure is also written to this path.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('Training History - Siamese Network для корреляции снимков',
                 fontsize=16, fontweight='bold')
    epochs = range(1, len(history['train_loss']) + 1)

    # Loss panel: the only one with two curves and a legend.
    loss_ax = axes[0, 0]
    loss_ax.plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    loss_ax.plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Loss Curves')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Validation metrics bounded to [0, 1]; all four panels share one layout.
    bounded_panels = (
        ((0, 1), 'val_accuracy', 'g-', 'Accuracy', 'Validation Accuracy'),
        ((0, 2), 'val_f1', 'm-', 'F1 Score', 'Validation F1 Score'),
        ((1, 0), 'val_precision', 'c-', 'Precision', 'Validation Precision'),
        ((1, 1), 'val_recall', 'y-', 'Recall', 'Validation Recall'),
    )
    for position, key, line_style, y_label, title in bounded_panels:
        panel = axes[position]
        panel.plot(epochs, history[key], line_style, linewidth=2)
        panel.set_xlabel('Epoch')
        panel.set_ylabel(y_label)
        panel.set_title(title)
        panel.grid(True, alpha=0.3)
        panel.set_ylim([0, 1])

    # Learning-rate panel uses a logarithmic y axis.
    lr_ax = axes[1, 2]
    lr_ax.plot(epochs, history['learning_rate'], 'k-', linewidth=2)
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_title('Learning Rate Schedule')
    lr_ax.grid(True, alpha=0.3)
    lr_ax.set_yscale('log')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Training history plot saved to {save_path}')
    plt.show()
def plot_roc_curve(predictions: torch.Tensor, targets: torch.Tensor, save_path: str = None):
    """Plot the ROC curve of the predicted scores and return its AUC.

    Args:
        predictions: Tensor of predicted scores; flattened before use.
        targets: Tensor of ground-truth labels; flattened before use.
        save_path: When truthy, the figure is also written to this path.

    Returns:
        The area under the ROC curve.
    """
    from sklearn.metrics import roc_curve, auc

    scores = predictions.numpy().flatten()
    truth = targets.numpy().flatten()
    # The third return value (thresholds) is not needed for the plot.
    fpr, tpr, _ = roc_curve(truth, scores)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(10, 8))
    curve_label = f'ROC curve (AUC = {roc_auc:.3f})'
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=curve_label)
    # Diagonal reference line, labelled 'Random'.
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate', fontsize=12)
    plt.title('ROC Curve - Siamese Network', fontsize=14, fontweight='bold')
    plt.legend(loc="lower right", fontsize=12)
    plt.grid(True, alpha=0.3)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'ROC curve saved to {save_path}')
    plt.show()
    return roc_auc
def plot_confusion_matrix(predictions: torch.Tensor, targets: torch.Tensor,
                          threshold: float = 0.5, save_path: str = None):
    """Plot the 2x2 confusion matrix of thresholded predictions.

    Args:
        predictions: Tensor of predicted scores; a score ``>= threshold``
            counts as class 1.
        targets: Tensor of ground-truth labels; a value ``>= 0.5`` counts as
            class 1.
        threshold: Decision threshold applied to ``predictions``.
        save_path: When truthy, the figure is also written to this path.
    """
    from sklearn.metrics import confusion_matrix

    predictions_binary = (predictions.numpy().flatten() >= threshold).astype(int)
    targets_binary = (targets.numpy().flatten() >= 0.5).astype(int)
    classes = ['Different Domains', 'Same Domain']
    # Pin the label set so the matrix is always 2x2. Without it, data that
    # contains a single class yields a 1x1 matrix that no longer matches the
    # two tick labels below.
    cm = confusion_matrix(targets_binary, predictions_binary, labels=[0, 1])

    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix - Correlation Detection', fontsize=14, fontweight='bold')
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, fontsize=12)
    plt.yticks(tick_marks, classes, fontsize=12)

    # Annotate each cell; switch to white text on the darker cells.
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black",
                     fontsize=16, fontweight='bold')

    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Confusion matrix saved to {save_path}')
    plt.show()
def plot_similarity_distribution(predictions: torch.Tensor, targets: torch.Tensor,
                                 save_path: str = None):
    """Plot and summarize predicted similarity scores per target class.

    Scores are split by ``targets >= 0.5`` ("same domain") versus
    ``targets < 0.5`` ("different domains"). The figure shows overlaid
    histograms and a box plot; mean/std/min/max are printed afterwards.

    Args:
        predictions: Tensor of predicted scores; flattened before use.
        targets: Tensor of ground-truth labels; flattened before use.
        save_path: When truthy, the figure is also written to this path.
    """
    predictions_np = predictions.numpy().flatten()
    targets_np = targets.numpy().flatten()
    same_domain = predictions_np[targets_np >= 0.5]
    diff_domain = predictions_np[targets_np < 0.5]

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.hist(same_domain, bins=50, alpha=0.7, color='green', edgecolor='black', label='Same Domain')
    plt.hist(diff_domain, bins=50, alpha=0.7, color='red', edgecolor='black', label='Different Domains')
    plt.xlabel('Predicted Similarity Score', fontsize=12)
    plt.ylabel('Frequency', fontsize=12)
    plt.title('Distribution of Similarity Scores', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    # NOTE(review): newer matplotlib releases rename ``labels`` to
    # ``tick_labels``; kept as-is for compatibility with older versions.
    plt.boxplot([diff_domain, same_domain], labels=['Different', 'Same'])
    plt.ylabel('Similarity Score', fontsize=12)
    plt.xlabel('Domain Match', fontsize=12)
    plt.title('Similarity Score by Domain Match', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Similarity distribution plot saved to {save_path}')
    plt.show()

    # Print statistics
    print('\n--- Similarity Score Statistics ---')
    groups = (
        ('Same Domain:', same_domain),
        ('\nDifferent Domains:', diff_domain),
    )
    for header, scores in groups:
        print(header)
        if scores.size == 0:
            # min()/max() raise on an empty array and mean()/std() return
            # NaN, so a class with no samples is reported explicitly.
            print(' (no samples)')
            continue
        print(f' Mean: {scores.mean():.4f}')
        print(f' Std: {scores.std():.4f}')
        print(f' Min: {scores.min():.4f}')
        print(f' Max: {scores.max():.4f}')
def visualize_sample_predictions(model: nn.Module, dataset, device: torch.device,
                                 num_samples: int = 8, save_path: str = None):
    """Show random image pairs next to the model's prediction for each.

    Every row contains the Google image, the Yandex image and a text box with
    the predicted score, the predicted label (score ``>= 0.5``), the true
    label and whether the two agree.

    Args:
        model: Network called as ``model(google_img, yandex_img)``; its
            output is read with ``.item()``.
        dataset: Indexable collection of samples with the keys
            ``'google_img'``, ``'yandex_img'`` and ``'same_domain'``.
        device: Device the images are moved to before inference.
        num_samples: Number of rows to draw; capped at ``len(dataset)``.
        save_path: When truthy, the figure is also written to this path.

    Raises:
        ValueError: If there is nothing to draw (empty dataset or
            ``num_samples < 1``).
    """
    model.eval()
    available = len(dataset)
    # Sampling without replacement fails when asking for more items than
    # exist, so cap the request at the dataset size.
    num_samples = min(num_samples, available)
    if num_samples < 1:
        raise ValueError('No samples available to visualize')
    indices = np.random.choice(available, num_samples, replace=False)

    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))
    if num_samples == 1:
        # A single row comes back one-dimensional; restore the 2-D layout.
        axes = axes.reshape(1, -1)
    fig.suptitle('Sample Predictions - Siamese Network для корреляции карт',
                 fontsize=16, fontweight='bold')

    # Denormalization constants (assuming ImageNet normalization — confirm
    # against the dataset transforms).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    with torch.no_grad():
        for row, sample_idx in enumerate(indices):
            # Plain int: not every dataset accepts numpy integer indices.
            sample = dataset[int(sample_idx)]
            google_img = sample['google_img'].unsqueeze(0).to(device)
            yandex_img = sample['yandex_img'].unsqueeze(0).to(device)
            true_label = sample['same_domain'].item()

            # Predict
            pred_similarity = model(google_img, yandex_img).item()
            pred_label = int(pred_similarity >= 0.5)

            # Channels-first -> channels-last, then undo the normalization.
            google_np = google_img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
            yandex_np = yandex_img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
            google_np = np.clip(std * google_np + mean, 0, 1)
            yandex_np = np.clip(std * yandex_np + mean, 0, 1)

            # Plot Google image
            axes[row, 0].imshow(google_np)
            axes[row, 0].set_title('Google Map', fontsize=12, fontweight='bold')
            axes[row, 0].axis('off')
            # Plot Yandex image
            axes[row, 1].imshow(yandex_np)
            axes[row, 1].set_title('Yandex Map', fontsize=12, fontweight='bold')
            axes[row, 1].axis('off')

            # Prediction summary; box colour encodes correctness.
            axes[row, 2].axis('off')
            is_correct = (pred_label == true_label)
            color = 'green' if is_correct else 'red'
            result = '✓ Correct' if is_correct else '✗ Incorrect'
            info_text = (
                f"\nPrediction: {pred_similarity:.4f}\n"
                f"Predicted Label: {'Same' if pred_label == 1 else 'Different'}\n"
                f"True Label: {'Same' if true_label == 1 else 'Different'}\n"
                f"{result}\n"
            )
            axes[row, 2].text(0.5, 0.5, info_text,
                              ha='center', va='center',
                              fontsize=12,
                              bbox=dict(boxstyle='round', facecolor=color, alpha=0.2),
                              transform=axes[row, 2].transAxes)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Sample predictions saved to {save_path}')
    plt.show()
def visualize_feature_space(model: nn.Module, dataloader: DataLoader,
                            device: torch.device, max_samples: int = 500,
                            save_path: str = None):
    """Project extracted image features to 2-D with t-SNE and plot them.

    Features of the Google and Yandex images are extracted with
    ``model.extract_features``, embedded together in one t-SNE run and drawn
    in two side-by-side scatter plots coloured by the ``same_domain`` label.

    Args:
        model: Network exposing ``extract_features(images)``.
        dataloader: Loader yielding batches with the keys ``'google_img'``,
            ``'yandex_img'`` and ``'same_domain'``.
        device: Device the images are moved to before feature extraction.
        max_samples: Upper bound on the number of pairs that are embedded.
        save_path: When truthy, the figure is also written to this path.

    Raises:
        ValueError: If no samples could be collected.
    """
    from sklearn.manifold import TSNE

    model.eval()
    google_chunks = []
    yandex_chunks = []
    label_chunks = []
    collected = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Extracting features'):
            # Count real samples instead of relying on dataloader.batch_size,
            # which is None when a batch sampler is used.
            if collected >= max_samples:
                break
            google_img = batch['google_img'].to(device)
            yandex_img = batch['yandex_img'].to(device)
            # Flatten so boolean masks built from the labels stay 1-D.
            label_chunks.append(batch['same_domain'].cpu().numpy().reshape(-1))
            # Extract features
            google_chunks.append(model.extract_features(google_img).cpu().numpy())
            yandex_chunks.append(model.extract_features(yandex_img).cpu().numpy())
            collected += google_img.size(0)

    if not google_chunks:
        raise ValueError('No samples collected for feature-space visualization')

    # Trim the overshoot of the last batch so max_samples is a real cap.
    features_google = np.concatenate(google_chunks, axis=0)[:max_samples]
    features_yandex = np.concatenate(yandex_chunks, axis=0)[:max_samples]
    labels = np.concatenate(label_chunks, axis=0)[:max_samples]
    n_samples = len(labels)

    # Both sources go through the same t-SNE run: Google rows first, then
    # Yandex rows.
    all_features = np.concatenate([features_google, features_yandex], axis=0)
    print(f'\nApplying t-SNE to {all_features.shape[0]} samples...')
    # t-SNE requires perplexity < number of points; keep the default of 30
    # whenever there are enough points.
    perplexity = min(30, all_features.shape[0] - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    features_2d = tsne.fit_transform(all_features)

    # Split back into Google and Yandex
    features_google_2d = features_2d[:n_samples]
    features_yandex_2d = features_2d[n_samples:]

    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    panels = (
        (axes[0], features_google_2d, 'Google Maps Features (t-SNE)'),
        (axes[1], features_yandex_2d, 'Yandex Maps Features (t-SNE)'),
    )
    for ax, points, title in panels:
        for label in (0, 1):
            mask = labels == label
            ax.scatter(
                points[mask, 0],
                points[mask, 1],
                c='green' if label == 1 else 'red',
                label='Same Domain' if label == 1 else 'Different Domains',
                alpha=0.6,
                s=50,
            )
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('t-SNE Component 1', fontsize=12)
        ax.set_ylabel('t-SNE Component 2', fontsize=12)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Feature space visualization saved to {save_path}')
    plt.show()
def generate_correlation_heatmap(model: nn.Module, dataloader: DataLoader,
                                 device: torch.device, num_samples: int = 20,
                                 save_path: str = None):
    """Plot the all-pairs similarity matrix between Google and Yandex images.

    The first sample of each batch is collected until ``num_samples`` pairs
    are available. Cell ``(i, j)`` of the heatmap is the model output for
    Google image ``i`` against Yandex image ``j``.

    Args:
        model: Network called as ``model(google_img, yandex_img)``; its
            output is read with ``.item()``.
        dataloader: Loader yielding batches with the keys ``'google_img'``,
            ``'yandex_img'`` and ``'same_domain'``.
        device: Device the images are moved to before inference.
        num_samples: Maximum number of pairs; fewer are used when the loader
            yields fewer batches.
        save_path: When truthy, the figure is also written to this path.

    Raises:
        ValueError: If no samples could be collected.
    """
    model.eval()
    # Collect samples
    google_images = []
    yandex_images = []
    labels = []
    with torch.no_grad():
        for batch in dataloader:
            if len(google_images) >= num_samples:
                break
            # Only the first item of every batch is used.
            google_images.append(batch['google_img'][:1].to(device))
            yandex_images.append(batch['yandex_img'][:1].to(device))
            labels.append(batch['same_domain'][:1].item())

    if not google_images:
        raise ValueError('No samples collected for the correlation heatmap')
    # The loader may yield fewer batches than requested; size everything by
    # what was actually collected so the loops below never index past it.
    num_samples = len(google_images)
    google_images = torch.cat(google_images, dim=0)
    yandex_images = torch.cat(yandex_images, dim=0)

    # Compute similarity matrix
    similarity_matrix = np.zeros((num_samples, num_samples))
    with torch.no_grad():
        for i in tqdm(range(num_samples), desc='Computing correlations'):
            google_i = google_images[i:i + 1]
            for j in range(num_samples):
                yandex_j = yandex_images[j:j + 1]
                similarity_matrix[i, j] = model(google_i, yandex_j).item()

    # Plot heatmap
    plt.figure(figsize=(14, 12))
    im = plt.imshow(similarity_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
    plt.colorbar(im, label='Similarity Score', fraction=0.046, pad=0.04)
    plt.title('Correlation Heatmap: Google vs Yandex Maps\n(Матрица корреляций снимков)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Yandex Map Index', fontsize=12)
    plt.ylabel('Google Map Index', fontsize=12)
    # Add grid
    plt.grid(True, which='both', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    # Mark diagonal cells whose pair is labelled as a true match. Only the
    # first 10 are annotated for readability. The marker string was empty
    # (which draws nothing); use the check mark this file prints elsewhere.
    for i in range(min(num_samples, 10)):
        if labels[i] == 1:
            plt.text(i, i, '✓', ha='center', va='center',
                     color='white', fontsize=12, fontweight='bold')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Correlation heatmap saved to {save_path}')
    plt.show()

    # Print statistics
    diagonal = np.diag(similarity_matrix)
    off_diagonal = similarity_matrix[~np.eye(num_samples, dtype=bool)]
    print('\n--- Correlation Statistics ---')
    print('Diagonal (matched pairs):')
    print(f' Mean: {diagonal.mean():.4f}')
    print(f' Std: {diagonal.std():.4f}')
    print('\nOff-diagonal (mismatched pairs):')
    if off_diagonal.size == 0:
        # A 1x1 matrix has no off-diagonal cells; avoid NaN statistics.
        print(' (no off-diagonal pairs)')
    else:
        print(f' Mean: {off_diagonal.mean():.4f}')
        print(f' Std: {off_diagonal.std():.4f}')
# =============================================================================
# MAIN TRAINING SCRIPT
# =============================================================================
def main():
    """Train the similarity model and generate all evaluation artifacts.

    Steps: copy the module-level ``config``, build the data loaders and the
    model, train with ``SimilarityTrainer``, then write the training curves,
    ROC curve, confusion matrix, score distribution, sample predictions,
    t-SNE feature plot and correlation heatmap to
    ``<output_dir>/visualizations`` and print a final summary.
    """
    banner = '=' * 70

    # Configuration
    config_dict = config.copy()
    if isinstance(config_dict.get('image_size'), list):
        config_dict['image_size'] = tuple(config_dict['image_size'])

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'\n{banner}')
    print('Siamese Network Training for Map Correlation')
    print('Обучение сиамской сети для корреляции снимков')
    print(f'{banner}')
    print(f'Using device: {device}')
    if torch.cuda.is_available():
        print(f'GPU: {torch.cuda.get_device_name(0)}')
        print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')

    # Create data loaders
    print(f'\n{banner}')
    print('Creating data loaders...')
    print(f'{banner}')
    train_loader, val_loader = create_data_loaders(
        root_dir=config_dict['data_dir'],
        batch_size=config_dict['batch_size'],
        train_split=config_dict['train_split'],
        num_workers=config_dict['num_workers'],
        image_size=config_dict['image_size'],
        augment_train=True,
        augment_val=False,
        device=device,
    )
    print(f'Train batches: {len(train_loader)}')
    print(f'Val batches: {len(val_loader)}')
    print(f'Train samples: {len(train_loader.dataset)}')
    print(f'Val samples: {len(val_loader.dataset)}')

    # Create model
    print(f'\n{banner}')
    print('Creating model...')
    print(f'{banner}')
    image_size = config_dict['image_size']
    # The model takes a single size: the first entry of a (tuple/list)
    # image size, or the value itself otherwise.
    input_size = image_size[0] if isinstance(image_size, (tuple, list)) else image_size
    model = create_similarity_model(
        model_type='backbone',
        input_size=input_size,
        input_channels=3,
        backbone_name='resnet18',
        pretrained=True,
        dropout_rate=0.3,
        use_batch_norm=True,
    )
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total parameters: {total_params:,}')
    print(f'Trainable parameters: {trainable_params:,}')

    # Create trainer
    trainer = SimilarityTrainer(
        model=model,
        trainloader=train_loader,
        valloader=val_loader,
        device=device,
        config=config_dict,
    )

    # Train model
    print(f'\n{banner}')
    print('Starting training...')
    print(f'{banner}')
    history = trainer.train(config_dict['epochs'])

    # =========================================================================
    # VISUALIZATION AND RESULTS
    # =========================================================================
    # NOTE(review): everything below uses the weights from the last epoch,
    # not the best checkpoint saved during training.
    print(f'\n{banner}')
    print('Generating visualizations...')
    print(f'{banner}')
    output_dir = config_dict.get('output_dir', 'runs/similarity')
    vis_dir = os.path.join(output_dir, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)

    # 1. Training history
    print('\n1. Plotting training history...')
    plot_training_history(
        history,
        save_path=os.path.join(vis_dir, 'training_history.png'),
    )

    # 2. Validation predictions: reuse the trainer's validation pass, which
    # already returns the concatenated CPU predictions and targets, instead
    # of repeating the same loop here.
    print('\n2. Computing validation predictions...')
    final_validation = trainer.validate()
    val_predictions = final_validation['predictions']
    val_targets = final_validation['targets']

    # 3. ROC curve
    print('\n3. Plotting ROC curve...')
    roc_auc = plot_roc_curve(
        val_predictions,
        val_targets,
        save_path=os.path.join(vis_dir, 'roc_curve.png'),
    )
    print(f'ROC AUC Score: {roc_auc:.4f}')

    # 4. Confusion matrix
    print('\n4. Plotting confusion matrix...')
    plot_confusion_matrix(
        val_predictions,
        val_targets,
        threshold=0.5,
        save_path=os.path.join(vis_dir, 'confusion_matrix.png'),
    )

    # 5. Similarity distribution
    print('\n5. Plotting similarity distribution...')
    plot_similarity_distribution(
        val_predictions,
        val_targets,
        save_path=os.path.join(vis_dir, 'similarity_distribution.png'),
    )

    # 6. Sample predictions
    print('\n6. Visualizing sample predictions...')
    visualize_sample_predictions(
        trainer.model,
        val_loader.dataset,
        device,
        num_samples=8,
        save_path=os.path.join(vis_dir, 'sample_predictions.png'),
    )

    # 7. Feature space visualization
    print('\n7. Visualizing feature space (t-SNE)...')
    visualize_feature_space(
        trainer.model,
        val_loader,
        device,
        max_samples=500,
        save_path=os.path.join(vis_dir, 'feature_space_tsne.png'),
    )

    # 8. Correlation heatmap
    print('\n8. Generating correlation heatmap...')
    generate_correlation_heatmap(
        trainer.model,
        val_loader,
        device,
        num_samples=20,
        save_path=os.path.join(vis_dir, 'correlation_heatmap.png'),
    )

    # =========================================================================
    # FINAL RESULTS SUMMARY
    # =========================================================================
    print(f'\n{banner}')
    print('FINAL RESULTS SUMMARY')
    print('ИТОГОВЫЕ РЕЗУЛЬТАТЫ')
    print(f'{banner}')
    print(f'\nBest Validation Loss: {trainer.best_val_loss:.4f}')
    print(f'Final Validation Accuracy: {history["val_accuracy"][-1]:.4f}')
    print(f'Final Validation F1 Score: {history["val_f1"][-1]:.4f}')
    print(f'Final Validation Precision: {history["val_precision"][-1]:.4f}')
    print(f'Final Validation Recall: {history["val_recall"][-1]:.4f}')
    print(f'ROC AUC Score: {roc_auc:.4f}')
    print(f'\nCheckpoints saved to: {os.path.join(output_dir, "checkpoints")}')
    print(f'Visualizations saved to: {vis_dir}')
    print(f'\n{banner}')
    print('Training and visualization completed successfully!')
    print('Обучение и визуализация завершены успешно!')
    print(f'{banner}\n')


# Run the pipeline only when the file is executed as a script.
if __name__ == '__main__':
    main()