diff --git a/models/SiaN-similarity/README.md b/models/SiaN-similarity/README.md new file mode 100644 index 0000000..28c31da --- /dev/null +++ b/models/SiaN-similarity/README.md @@ -0,0 +1,131 @@ +# SiaN-Similarity: Модель для оценки схожести изображений + +Модель для оценки схожести между двумя изображениями 256x256. Возвращает значение от 0 до 1, где 1 означает полную схожесть, 0 - полное различие. + +## Архитектура модели + +Модель основана на CNN с residual блоками: +- Общий энкодер для обоих изображений +- Residual blocks с batch normalization +- Слой слияния признаков +- Регрессионная голова с сигмоидой на выходе + +## Использование + +### Установка зависимостей +```bash +pip install torch torchvision pillow +``` + +### Быстрый старт + +```python +import torch +from model import SimilarityCNN + +# Создание модели +model = SimilarityCNN( + input_channels=3, + hidden_channels=64, + num_blocks=4, + dropout_rate=0.3, + use_batch_norm=True, +) + +# Предсказание схожести +img1 = torch.randn(1, 3, 256, 256) # Изображение 1 +img2 = torch.randn(1, 3, 256, 256) # Изображение 2 + +similarity = model.predict_similarity(img1, img2) +print(f"Схожесть: {similarity.item():.4f}") +``` + +### Обучение модели + +```bash +python train_similarity.py \ + --data_dir "путь/к/данным" \ + --batch_size 32 \ + --epochs 100 \ + --learning_rate 2e-4 \ + --output_dir "runs/similarity" +``` + +### Предсказание на новых изображениях + +```bash +python predict.py \ + --image1 "путь/к/изображению1.png" \ + --image2 "путь/к/изображению2.png" \ + --checkpoint "runs/similarity/checkpoints/best_model.pt" +``` + +## Структура проекта + +``` +SiaN-similarity/ +├── model.py # Основная модель +├── dataloader.py # Даталоадер для обучения +├── train_similarity.py # Скрипт для обучения +├── predict.py # Скрипт для предсказания +├── train.py # Оригинальный тренировочный скрипт +└── README.md # Этот файл +``` + +## Конфигурация модели + +Параметры по умолчанию: +- `input_channels`: 3 (RGB) +- 
# Module-level settings used by the data-loading code in this file.
# NOTE(review): the optimizer / GAN / logging entries look like leftovers from a
# GAN training script; nothing visible in this file reads them -- confirm before
# removing.
config = {
    # Optimizer parameters
    "learning_rate": 2e-4,
    "beta1": 0.5,
    "beta2": 0.999,
    # Training parameters
    # The original literal listed "batch_size" twice (4 here, 32 further down).
    # Python keeps the last value of a repeated key, so 32 is what every reader
    # of this dict has always seen; the dead duplicate is removed.
    "batch_size": 32,
    "epochs": 100,
    # GAN parameters
    "gan_mode": "vanilla",  # "vanilla", "lsgan", or "wgangp"
    "lambda_L1": 100.0,  # Weight of the L1 loss
    # Regularization
    "grad_clip": 1.0,
    # Early stopping
    "early_stopping_patience": 20,
    # Outputs
    "output_dir": "runs/gan_training",
    # Logging
    "log_interval": 10,  # Log every N batches
    "save_interval": 5,  # Save a checkpoint every N epochs
    # Dataset location and loading parameters
    # NOTE(review): machine-specific absolute path -- override before running elsewhere.
    "data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
    "image_size": [256, 256],
    "train_split": 0.8,
    "num_workers": 0,
}
+ +import os +import random +from typing import Any, Dict, List, Optional, Tuple + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset + + +class YaGoDataset(Dataset): + """ + Dataset for homography estimation between Yandex and Google map image pairs. + + This dataset loads pairs of images (Yandex and Google maps) and provides + homography matrices for data augmentation and training. + """ + + def __init__( + self, + root_dir: str, + transform=None, + augment: bool = True, + max_samples: Optional[int] = None, + image_size: Tuple[int, int] = (700, 700), + cache_homographies: bool = True, + device: torch.device = None, + ): + """ + Initialize the YaGoDataset. + + Args: + root_dir: Directory containing image pairs (format: {idx:04d}_google.png, {idx:04d}_yandex.png) + transform: Optional torchvision transforms to apply + augment: Whether to apply homography-based data augmentation + max_samples: Maximum number of samples to load (None for all) + image_size: Target size for images (height, width) + cache_homographies: Whether to cache generated homography matrices to disk + """ + self.root_dir = root_dir + self.transform = transform + self.augment = augment + self.image_size = image_size + self.cache_homographies = cache_homographies + self.device = device + + # Find all image pairs + self.image_pairs = self._discover_image_pairs() + + if max_samples is not None: + self.image_pairs = self.image_pairs[:max_samples] + + print(f"Found {len(self.image_pairs)} image pairs in {root_dir}") + + def _discover_image_pairs(self) -> List[Dict[str, Any]]: + """Discover all Google-Yandex image pairs in the dataset directory.""" + image_pairs = [] + + # Get all Google images + google_files = [ + f for f in os.listdir(self.root_dir) if f.endswith("_google.png") + ] + + for google_file in sorted(google_files): + # Extract index from filename + idx_str = google_file.split("_")[0] + try: + idx = int(idx_str) + except 
def __getitem__(self, idx: int) -> Dict[str, Any]:
    """
    Load one (Google, Yandex) sample, randomly turning it into a negative pair.

    With probability 0.5 the Google image is replaced by the Google image of a
    uniformly drawn pair; ``same_domain`` then records whether both images
    still belong to the same pair.

    Args:
        idx: Position in ``self.image_pairs``.

    Returns:
        A dictionary with:
        - 'google_img': Google map image tensor
        - 'yandex_img': Yandex map image tensor
        - 'homography': float32 3x3 matrix -- the relative homography produced
          by the augmentation, or the identity when ``augment`` is False
        - 'same_domain': bool, True when both images come from the same pair
        - 'idx': Sample index parsed from the file name (long tensor)
    """
    pair_info = self.image_pairs[idx]
    google_path = pair_info["google_path"]
    yandex_path = pair_info["yandex_path"]
    same_domain = True

    # Negative sampling: half of the time pair the Yandex image with a random
    # Google image. The draw may hit the same pair again, hence the comparison.
    if np.random.rand() > 0.5:
        random_idx = np.random.randint(0, len(self))
        google_path = self.image_pairs[random_idx]["google_path"]
        same_domain = random_idx == idx

    # PIL's resize takes (width, height) while image_size is (height, width).
    target_size = (self.image_size[1], self.image_size[0])

    # Use context managers so the underlying file handles are closed right away
    # instead of lingering until garbage collection.
    with Image.open(yandex_path) as raw:
        yandex_img = raw.convert("RGB").resize(target_size, Image.BILINEAR)
    with Image.open(google_path) as raw:
        google_img = raw.convert("RGB").resize(target_size, Image.BILINEAR)

    if self.augment:
        # Random matrices are only generated when they are actually used.
        matrices = self._get_homography_matrix(pair_info["idx"])
        google_img, yandex_img, homography_matrix = self._apply_augmentation(
            google_img, yandex_img, matrices
        )
        homography_tensor = torch.from_numpy(homography_matrix).float()
    else:
        # float32 identity, matching the dtype of the augmented branch
        # (np.eye would yield float64 and break later float32 matmuls).
        homography_tensor = torch.eye(3, dtype=torch.float32)

    if self.transform:
        google_img = self.transform(google_img)
        yandex_img = self.transform(yandex_img)
    else:
        # Default conversion: HWC uint8 image -> CHW float tensor in [0, 1].
        google_img = (
            torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0
        )
        yandex_img = (
            torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0
        )

    return {
        "google_img": google_img,
        "yandex_img": yandex_img,
        "homography": homography_tensor,
        "same_domain": same_domain,
        "idx": torch.tensor(pair_info["idx"], dtype=torch.long),
    }
def get_camera_matrix(self) -> np.ndarray:
    """
    Build the 3x3 camera-style matrix K used to conjugate random transforms.

    K scales by half the image width/height and shifts by the image centre.
    The size is taken from this dataset's own ``image_size`` (height, width)
    rather than from the module-level ``config``, so datasets created with a
    different size get a matching matrix.

    Returns:
        np.ndarray: 3x3 matrix ``[[w/2, 0, w/2], [0, h/2, h/2], [0, 0, 1]]``.
    """
    h, w = self.image_size

    return np.array(
        [
            [w / 2, 0, w / 2],
            [0, h / 2, h / 2],
            [0, 0, 1],
        ]
    )
def _apply_homography_to_image(
    self, img: Image.Image, homography: np.ndarray
) -> Image.Image:
    """
    Warp a PIL image with a 3x3 homography, keeping the original canvas size.

    Args:
        img: PIL Image to transform
        homography: 3x3 homography matrix

    Returns:
        The warped image as a new PIL Image
    """
    pixels = np.array(img)
    height, width = pixels.shape[:2]

    # OpenCV takes the output size as (width, height); the border mode is left
    # at the library default.
    warped = cv2.warpPerspective(
        pixels, homography, (width, height), flags=cv2.INTER_LINEAR
    )
    return Image.fromarray(warped)
class _HomographyAugmentedSubset(Dataset):
    """
    View over selected samples that composes a random homography into the label.

    Defined at module level (not inside ``create_data_loaders``) so that it can
    be pickled when DataLoader workers are spawned. It derives from ``Dataset``
    rather than ``Subset`` on purpose: in PyTorch 2.x ``Subset`` provides a
    batched ``__getitems__`` that ``DataLoader`` prefers, which would bypass an
    overridden ``__getitem__``.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx):
        sample = self.dataset[self.indices[idx]]

        # Both operands are forced to float32; mixing the float64 identity some
        # base samples carry with a float32 matrix makes ``@`` raise.
        aug_homography = torch.from_numpy(
            self.dataset.generate_random_homography()
        ).float()
        combined_homography = aug_homography @ sample["homography"].float()

        # Simplified augmentation: the images are NOT warped, only the label is
        # updated. All other keys (including "same_domain") pass through.
        return {**sample, "homography": combined_homography}


def create_data_loaders(
    root_dir: str,
    batch_size: int = 32,
    train_split: float = 0.8,
    num_workers: int = 4,
    image_size: Tuple[int, int] = (256, 256),
    augment_train: bool = True,
    augment_val: bool = False,
    device: torch.device = None,
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders over a randomly split YaGoDataset.

    Args:
        root_dir: Directory containing image pairs
        batch_size: Batch size for data loaders
        train_split: Fraction of data to use for training
        num_workers: Number of worker processes for data loading
        image_size: Target image size (height, width)
        augment_train: Whether to compose a random homography into train labels
        augment_val: Whether to do the same for validation labels
        device: Forwarded to the dataset for backward compatibility. Samples
            are returned on the CPU so that ``pin_memory`` and worker
            processes keep working; move batches with ``.to(device)`` in the
            training loop.

    Returns:
        Tuple of (train_loader, val_loader)
    """
    from torch.utils.data import Subset
    from torchvision import transforms

    # ImageNet mean/std normalization on top of the [0, 1] tensor conversion.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    full_dataset = YaGoDataset(
        root_dir=root_dir,
        transform=transform,
        augment=False,  # Augmentation is handled by the subset wrappers below
        image_size=image_size,
        cache_homographies=True,
        device=device,
    )

    # Random train/validation split over sample positions.
    dataset_size = len(full_dataset)
    train_size = int(train_split * dataset_size)
    indices = list(range(dataset_size))
    random.shuffle(indices)
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]

    train_dataset = (
        _HomographyAugmentedSubset(full_dataset, train_indices)
        if augment_train
        else Subset(full_dataset, train_indices)
    )
    val_dataset = (
        _HomographyAugmentedSubset(full_dataset, val_indices)
        if augment_val
        else Subset(full_dataset, val_indices)
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader


# Example usage. Guarded so that importing this module (e.g. from example.py)
# no longer scans the dataset directory or prints anything.
if __name__ == "__main__":
    dataset = YaGoDataset(
        root_dir=config["data_dir"],
        augment=False,
        image_size=(256, 256),
    )

    print(f"Dataset size: {len(dataset)}")

    # Get a sample
    sample = dataset[0]
    print(f"Sample keys: {list(sample.keys())}")
    print(f"Google image shape: {sample['google_img'].shape}")
    print(f"Yandex image shape: {sample['yandex_img'].shape}")
    print(f"Homography shape: {sample['homography'].shape}")

    # Create data loaders
    train_loader, val_loader = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=16,
        train_split=0.8,
    )

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
def create_test_images():
    """
    Draw four labelled 256x256 demo images.

    Returns:
        A list of ``(title, PIL image)`` tuples: a red square, a slightly
        shifted red square, a blue circle and a green triangle, in that order.
    """

    def _blank():
        # Fresh white canvas together with a drawing handle bound to it.
        canvas = Image.new("RGB", (256, 256), color="white")
        return canvas, ImageDraw.Draw(canvas)

    # Image 1: red square
    red_square, pen = _blank()
    pen.rectangle([50, 50, 200, 200], fill="red", outline="black", width=2)

    # Image 2: the same red square shifted by 5 px (the "similar" case)
    shifted_square, pen = _blank()
    pen.rectangle([55, 55, 205, 205], fill="red", outline="black", width=2)

    # Image 3: blue circle (different)
    blue_circle, pen = _blank()
    pen.ellipse([50, 50, 200, 200], fill="blue", outline="black", width=2)

    # Image 4: green triangle (different)
    green_triangle, pen = _blank()
    pen.polygon(
        [(128, 50), (50, 200), (200, 200)], fill="green", outline="black", width=2
    )

    return [
        ("Красный квадрат", red_square),
        ("Похожий красный квадрат", shifted_square),
        ("Синий круг", blue_circle),
        ("Зеленый треугольник", green_triangle),
    ]
def test_loss_function():
    """Run SimilarityLoss on a tiny hand-made batch and print loss and metrics."""
    rule = "=" * 60
    print("\n" + rule)
    print("ТЕСТИРОВАНИЕ ФУНКЦИИ ПОТЕРЬ")
    print(rule)

    loss_fn = SimilarityLoss()

    # Four predictions that all fall on the correct side of 0.5.
    predictions = torch.tensor([[0.9], [0.1], [0.7], [0.3]])
    targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])

    # Loss value
    loss_value = loss_fn(predictions, targets).item()
    print(f"\nПотери: {loss_value:.4f}")

    # Metrics
    metrics = loss_fn.compute_metrics(predictions, targets)
    print("\nМетрики:")
    for metric_name in metrics:
        print(f" {metric_name}: {metrics[metric_name]:.4f}")
тестовых изображений...") + test_images = create_test_images() + + # Преобразование изображений в тензоры + tensors = [] + for name, img in test_images: + tensor = preprocess_image(img).to(device) + tensors.append(tensor) + + # Расчет схожести между всеми парами изображений + print("\nРасчет схожести между изображениями...") + n_images = len(test_images) + similarity_matrix = np.zeros((n_images, n_images)) + + model.eval() + with torch.no_grad(): + for i in range(n_images): + for j in range(n_images): + if i <= j: # Рассчитываем только верхний треугольник + sim = model.predict_similarity(tensors[i], tensors[j]) + similarity_matrix[i, j] = sim.item() + similarity_matrix[j, i] = sim.item() # Симметричная матрица + + # Отображение результатов + display_results(test_images, similarity_matrix) + + # Тестирование функции потерь + test_loss_function() + + # Дополнительная информация + print("\n" + "=" * 60) + print("ИНФОРМАЦИЯ О МОДЕЛИ") + print("=" * 60) + print("\nАрхитектура модели:") + print("-" * 40) + print("Вход: два изображения 256x256x3") + print("Энкодер: CNN с residual блоками") + print("Слой слияния: объединение признаков") + print("Выход: значение схожести [0, 1]") + print("\nИнтерпретация результатов:") + print("- 0.8-1.0: Очень похожи") + print("- 0.6-0.8: Похожи") + print("- 0.4-0.6: Нейтрально") + print("- 0.2-0.4: Разные") + print("- 0.0-0.2: Совершенно разные") + + print("\n" + "=" * 60) + print("ДЕМОНСТРАЦИЯ ЗАВЕРШЕНА") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/models/SiaN-similarity/example.py b/models/SiaN-similarity/example.py new file mode 100644 index 0000000..d17cda7 --- /dev/null +++ b/models/SiaN-similarity/example.py @@ -0,0 +1,216 @@ +""" +Пример использования модели оценки схожести с даталоадером. 
def _batch_to_device(batch, device):
    """Return (google_img, yandex_img, same_domain as float column) on ``device``."""
    google_img = batch["google_img"].to(device)
    yandex_img = batch["yandex_img"].to(device)
    same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
    return google_img, yandex_img, same_domain


def main():
    """
    Walk through the similarity model end to end on the Yandex/Google dataset.

    Steps: build the dataset and loaders, create the model, score one batch,
    run a few optimisation steps, validate on a few batches and finally compare
    individual images. Everything is printed to stdout; nothing is returned.
    """
    print("ПРИМЕР ИСПОЛЬЗОВАНИЯ МОДЕЛИ СХОЖЕСТИ С ДАТАЛОАДЕРОМ")
    print("=" * 60)

    # Configuration
    # NOTE(review): machine-specific absolute path -- adjust before running.
    config = {
        "data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        "batch_size": 4,
        "image_size": (256, 256),
        "train_split": 0.8,
        "num_workers": 0,
    }

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Устройство: {device}")

    # 1. Dataset
    print("\n1. СОЗДАНИЕ ДАТАСЕТА")
    print("-" * 40)

    dataset = YaGoDataset(
        root_dir=config["data_dir"],
        augment=False,
        image_size=config["image_size"],
    )

    print(f"Размер датасета: {len(dataset)} пар изображений")

    # Peek at one sample
    sample = dataset[0]
    print("\nПример из датасета:")
    print(f" Google image shape: {sample['google_img'].shape}")
    print(f" Yandex image shape: {sample['yandex_img'].shape}")
    print(f" Same domain: {sample['same_domain']}")
    print(f" Index: {sample['idx'].item()}")

    # 2. Data loaders
    print("\n2. СОЗДАНИЕ ДАТАЛОАДЕРОВ")
    print("-" * 40)

    train_loader, val_loader = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=config["batch_size"],
        train_split=config["train_split"],
        num_workers=config["num_workers"],
        image_size=config["image_size"],
        augment_train=True,
        augment_val=False,
        device=device,
    )

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    # 3. Model
    print("\n3. СОЗДАНИЕ МОДЕЛИ")
    print("-" * 40)

    model = SimilarityCNN(
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    ).to(device)

    print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")

    # Created once, up front: the original built it inside the first-batch loop,
    # so later steps hit a NameError whenever the train loader was empty.
    loss_fn = SimilarityLoss().to(device)

    # 4. Score a single batch
    print("\n4. ТЕСТИРОВАНИЕ НА ОДНОМ БАТЧЕ")
    print("-" * 40)

    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        google_img, yandex_img, same_domain = _batch_to_device(first_batch, device)

        print(f"Batch size: {google_img.shape[0]}")
        print(f"Image shape: {google_img.shape[1:]}")
        print(f"Same domain labels: {same_domain.squeeze().tolist()}")

        # Similarity predictions
        with torch.no_grad():
            predictions = model.predict_similarity(google_img, yandex_img)
        print("\nПредсказания схожести:")
        for i, (pred, target) in enumerate(zip(predictions, same_domain)):
            print(f" Sample {i}: {pred.item():.4f} (target: {target.item():.1f})")

        # Loss
        loss = loss_fn(predictions, same_domain)
        print(f"\nПотери на батче: {loss.item():.4f}")

        # Metrics
        metrics = loss_fn.compute_metrics(predictions, same_domain)
        print("\nМетрики на батче:")
        for key, value in metrics.items():
            print(f" {key}: {value:.4f}")

    # 5. A few optimisation steps (demonstration only)
    print("\n5. ДЕМОНСТРАЦИЯ ОБУЧЕНИЯ НА ОДНОЙ ЭПОХЕ")
    print("-" * 40)

    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    model.train()

    total_loss = 0
    total_samples = 0

    for batch_idx, batch in enumerate(train_loader):
        if batch_idx >= 3:  # Limit the demo to 3 batches
            break

        google_img, yandex_img, same_domain = _batch_to_device(batch, device)

        optimizer.zero_grad()

        predictions = model(google_img, yandex_img)
        loss = loss_fn(predictions, same_domain)

        loss.backward()
        optimizer.step()

        # Weight by batch size so the average is per sample.
        total_loss += loss.item() * google_img.size(0)
        total_samples += google_img.size(0)

        print(f"Batch {batch_idx + 1}: loss = {loss.item():.4f}")

    # max(..., 1) avoids ZeroDivisionError when the loader yielded nothing.
    avg_loss = total_loss / max(total_samples, 1)
    print(f"\nСредние потери за 3 батча: {avg_loss:.4f}")

    # 6. Validation
    print("\n6. ВАЛИДАЦИЯ")
    print("-" * 40)

    model.eval()
    val_loss = 0
    val_samples = 0

    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            if batch_idx >= 2:  # Limit the demo to 2 batches
                break

            google_img, yandex_img, same_domain = _batch_to_device(batch, device)

            predictions = model.predict_similarity(google_img, yandex_img)
            loss = loss_fn(predictions, same_domain)

            val_loss += loss.item() * google_img.size(0)
            val_samples += google_img.size(0)

            print(f"Val batch {batch_idx + 1}: loss = {loss.item():.4f}")

    avg_val_loss = val_loss / max(val_samples, 1)
    print(f"\nСредние потери на валидации: {avg_val_loss:.4f}")

    # 7. Individual images
    print("\n7. ПРИМЕР ДЛЯ ОТДЕЛЬНЫХ ИЗОБРАЖЕНИЙ")
    print("-" * 40)

    # Take two samples from the dataset
    # NOTE(review): the dataset may swap in a random Google image, so these are
    # not guaranteed to be matching pairs -- check sample["same_domain"].
    sample1 = dataset[0]
    sample2 = dataset[1]

    # Add the batch dimension and move to the device
    img1_1 = sample1["google_img"].unsqueeze(0).to(device)
    img1_2 = sample1["yandex_img"].unsqueeze(0).to(device)

    img2_1 = sample2["google_img"].unsqueeze(0).to(device)
    img2_2 = sample2["yandex_img"].unsqueeze(0).to(device)

    with torch.no_grad():
        # Images taken from the same sample
        sim_same1 = model.predict_similarity(img1_1, img1_2)
        sim_same2 = model.predict_similarity(img2_1, img2_2)

        # Images taken from different samples
        sim_diff1 = model.predict_similarity(img1_1, img2_2)
        sim_diff2 = model.predict_similarity(img2_1, img1_2)

    print("Сравнение пар изображений:")
    print(f" Пара 1 (один домен): {sim_same1.item():.4f}")
    print(f" Пара 2 (один домен): {sim_same2.item():.4f}")
    print(f" Разные домены 1: {sim_diff1.item():.4f}")
    print(f" Разные домены 2: {sim_diff2.item():.4f}")

    print("\n" + "=" * 60)
    print("ПРИМЕР ЗАВЕРШЕН")
    print("=" * 60)
def _build_encoder(self) -> nn.Module:
    """
    Assemble the convolutional encoder applied to each input image.

    Layout: a 7x7 stride-2 convolution (optionally followed by batch norm),
    ReLU and a 3x3 stride-2 max-pool, then ``num_blocks`` residual blocks.
    The first two blocks double the channel count; later blocks keep it. The
    first block uses stride 1 and every following block uses stride 2.
    """
    width = self.hidden_channels

    # Stem
    stem = [
        nn.Conv2d(self.input_channels, width, kernel_size=7, stride=2, padding=3)
    ]
    if self.use_batch_norm:
        stem.append(nn.BatchNorm2d(width))
    stem.append(nn.ReLU(inplace=True))
    stem.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

    # Residual stages
    stages = []
    for block_idx in range(self.num_blocks):
        next_width = width * 2 if block_idx < 2 else width
        stages.append(
            ResidualBlock(
                in_channels=width,
                out_channels=next_width,
                stride=1 if block_idx == 0 else 2,
                dropout_rate=self.dropout_rate,
                use_batch_norm=self.use_batch_norm,
            )
        )
        width = next_width

    return nn.Sequential(*stem, *stages)
def predict_similarity(
    self,
    img1: torch.Tensor,
    img2: torch.Tensor,
) -> torch.Tensor:
    """
    Run an inference-only forward pass and return the similarity scores.

    The module is switched to eval mode and gradients are disabled for the
    duration of the call. If the module was in training mode beforehand, that
    mode is restored afterwards -- also when ``forward`` raises, so a failed
    prediction can no longer leave the model stuck in eval mode.

    Args:
        img1: First image batch.
        img2: Second image batch.

    Returns:
        Whatever ``forward(img1, img2)`` returns, computed without gradients.
    """
    was_training = self.training
    self.eval()
    try:
        with torch.no_grad():
            return self.forward(img1, img2)
    finally:
        if was_training:
            self.train()
use_batch_norm: bool = True, + ): + super().__init__() + + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ) + self.bn1 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity() + self.relu1 = nn.ReLU(inplace=True) + self.dropout1 = ( + nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity() + ) + + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity() + self.relu2 = nn.ReLU(inplace=True) + self.dropout2 = ( + nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity() + ) + + self.shortcut = nn.Sequential() + if stride != 1 or in_channels != out_channels: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=stride, bias=False + ), + nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = self.shortcut(x) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + out = self.dropout1(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += identity + out = self.relu2(out) + out = self.dropout2(out) + + return out + + +class SimilarityLoss(nn.Module): + def __init__(self): + super().__init__() + self.criterion = nn.BCELoss() + + def forward( + self, + pred_similarity: torch.Tensor, + target_same: torch.Tensor, + ) -> torch.Tensor: + return self.criterion(pred_similarity, target_same) + + def compute_metrics( + self, + pred_similarity: torch.Tensor, + target_same: torch.Tensor, + threshold: float = 0.5, + ) -> dict: + with torch.no_grad(): + pred_binary = (pred_similarity > threshold).float() + target_binary = (target_same > 0.5).float() + + correct = (pred_binary == target_binary).float() + accuracy = correct.mean().item() + + tp = ((pred_binary == 1) & (target_binary == 1)).float().sum().item() + fp 
= ((pred_binary == 1) & (target_binary == 0)).float().sum().item() + fn = ((pred_binary == 0) & (target_binary == 1)).float().sum().item() + tn = ((pred_binary == 0) & (target_binary == 0)).float().sum().item() + + precision = tp / (tp + fp + 1e-8) + recall = tp / (tp + fn + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + + return { + "accuracy": accuracy, + "precision": precision, + "recall": recall, + "f1": f1, + "mean_similarity": pred_similarity.mean().item(), + } + + +def create_similarity_model( + model_type: str = "cnn", + input_size: Tuple[int, int] = (256, 256), + **kwargs, +) -> nn.Module: + if model_type == "cnn": + return SimilarityCNN(**kwargs) + else: + raise ValueError(f"Unknown model type: {model_type}") + + +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + model = SimilarityCNN( + input_channels=3, + hidden_channels=64, + num_blocks=4, + dropout_rate=0.3, + use_batch_norm=True, + ).to(device) + + print( + f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters" + ) + + batch_size = 4 + height, width = 256, 256 + + img1 = torch.randn(batch_size, 3, height, width).to(device) + img2 = torch.randn(batch_size, 3, height, width).to(device) + + print("\nTesting forward pass...") + output = model(img1, img2) + print(f"Output shape: {output.shape}") + print(f"Sample output: {output[0].item():.4f}") + + print("\nTesting prediction...") + pred = model.predict_similarity(img1, img2) + print(f"Prediction shape: {pred.shape}") + + print("\nTesting loss function...") + target = torch.rand(batch_size, 1).to(device) + loss_fn = SimilarityLoss().to(device) + loss = loss_fn(output, target) + print(f"Loss value: {loss.item():.6f}") + + print("\nTesting metrics...") + metrics = loss_fn.compute_metrics(output, target) + for key, value in metrics.items(): + print(f"{key}: {value:.6f}") + + print("\nAll tests completed successfully!") diff 
# --git a/models/SiaN-similarity/predict.py b/models/SiaN-similarity/predict.py
# new file mode 100644
# index 0000000..efce6a5
# --- /dev/null
# +++ b/models/SiaN-similarity/predict.py
# @@ -0,0 +1,146 @@
"""
Script for predicting similarity between two images.
"""

import argparse
import os
from pathlib import Path

import torch
from model import SimilarityCNN
from PIL import Image
from torchvision import transforms


def load_image(image_path: str, image_size: tuple = (256, 256)) -> torch.Tensor:
    """Load an image file and return it as a normalised batch of one.

    The image is converted to RGB, resized to ``image_size``, turned into a
    tensor and normalised with fixed per-channel mean/std constants.

    Args:
        image_path: Path of the image file.
        image_size: Target size passed to ``transforms.Resize``.

    Returns:
        Tensor with a leading batch dimension of 1.
    """
    transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Image.open keeps the file open until the image is closed; convert()
    # returns an independent copy, so the handle can be released right away.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    return transform(image).unsqueeze(0)  # Add batch dimension


def predict_similarity(
    model: SimilarityCNN,
    image1_path: str,
    image2_path: str,
    device: torch.device,
    image_size: tuple = (256, 256),
) -> float:
    """Predict similarity between two image files.

    Puts ``model`` into eval mode (it is left in eval mode), loads both
    images onto ``device`` and runs one forward pass without gradients.

    Returns:
        The model output for the pair as a Python float.
    """
    model.eval()

    img1 = load_image(image1_path, image_size).to(device)
    img2 = load_image(image2_path, image_size).to(device)

    with torch.no_grad():
        similarity = model(img1, img2)

    return similarity.item()


def load_model(
    checkpoint_path: str,
    device: torch.device,
    **model_kwargs,
) -> SimilarityCNN:
    """Build a ``SimilarityCNN`` and load weights from a checkpoint file.

    Args:
        checkpoint_path: File written by ``torch.save``. Either a dict with a
            ``"model_state_dict"`` entry or a bare state dict.
        device: Device the model and the loaded tensors are placed on.
        **model_kwargs: Forwarded to ``SimilarityCNN``; must match the
            architecture the weights were saved from.
    """
    model = SimilarityCNN(**model_kwargs).to(device)

    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    return model


def main():
    """Command-line entry point: score one pair of images.

    Returns:
        The similarity score, or ``None`` when one of the input images does
        not exist. A missing checkpoint is not an error: a randomly
        initialised model is used and a warning is printed.
    """
    parser = argparse.ArgumentParser(
        description="Predict similarity between two images"
    )
    parser.add_argument("--image1", type=str, required=True, help="Path to first image")
    parser.add_argument(
        "--image2", type=str, required=True, help="Path to second image"
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="runs/similarity/checkpoints/best_model.pt",
        help="Path to model checkpoint",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for inference",
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=256,
        help="Image size for model input",
    )

    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    for image_path in (args.image1, args.image2):
        if not os.path.exists(image_path):
            print(f"Error: Image not found: {image_path}")
            return

    # Single source of truth for the architecture used in both branches below.
    model_kwargs = {
        "input_channels": 3,
        "hidden_channels": 64,
        "num_blocks": 4,
        "dropout_rate": 0.3,
        "use_batch_norm": True,
    }

    if not os.path.exists(args.checkpoint):
        print(f"Warning: Checkpoint not found: {args.checkpoint}")
        print("Using randomly initialized model for demonstration")
        model = SimilarityCNN(**model_kwargs).to(device)
    else:
        print(f"Loading model from: {args.checkpoint}")
        model = load_model(
            checkpoint_path=args.checkpoint,
            device=device,
            **model_kwargs,
        )

    print(
        f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters"
    )

    similarity = predict_similarity(
        model=model,
        image1_path=args.image1,
        image2_path=args.image2,
        device=device,
        image_size=(args.image_size, args.image_size),
    )

    print("\nSimilarity between images:")
    print(f" Image 1: {args.image1}")
    print(f" Image 2: {args.image2}")
    print(f" Similarity score: {similarity:.4f}")
    print(f" Interpretation: {'Similar' if similarity > 0.5 else 'Different'}")

    return similarity


if __name__ == "__main__":
    main()
# diff --git a/models/SiaN-similarity/train.py b/models/SiaN-similarity/train.py
# new file mode 100644
# index 0000000..1296a01
# --- /dev/null
# +++ b/models/SiaN-similarity/train.py
# @@ -0,0 +1,275 @@
"""
Training script for image similarity estimation.
"""

import argparse
import os
import time
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import create_data_loaders
from model import SimilarityCNN, SimilarityLoss, create_similarity_model
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


class SimilarityTrainer:
    """Training loop with validation, TensorBoard logging and checkpoints.

    Batches are dicts with the keys ``"google_img"``, ``"yandex_img"`` and
    ``"same_domain"``; the last one is used as the binary target.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        """Move the model to ``device`` and create the loss and Adam optimizer.

        Config keys read here: ``learning_rate`` (default 2e-4), ``beta1``
        (default 0.5) and ``beta2`` (default 0.999).
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        self.criterion = SimilarityLoss()
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get("learning_rate", 2e-4),
            betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)),
        )

        self.writer = None
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0

    def train_epoch(self, epoch: int) -> dict:
        """Train for one pass over ``train_loader``.

        Every ``log_interval`` batches (config, default 10) the progress bar
        is updated and, when a writer is attached, loss and accuracy are
        written to TensorBoard.

        Returns:
            ``{"loss": <sample-weighted mean training loss>}``.

        Raises:
            ValueError: If the loader yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        total_samples = 0
        log_interval = self.config.get("log_interval", 10)

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            google_img = batch["google_img"].to(self.device)
            yandex_img = batch["yandex_img"].to(self.device)
            target = batch["same_domain"].float().to(self.device).unsqueeze(1)

            self.optimizer.zero_grad()

            output = self.model(google_img, yandex_img)
            loss = self.criterion(output, target)

            loss.backward()
            self.optimizer.step()

            batch_size = google_img.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            if batch_idx % log_interval == 0:
                metrics = self.criterion.compute_metrics(output, target)
                pbar.set_postfix(
                    {
                        "loss": loss.item(),
                        "acc": metrics["accuracy"],
                    }
                )

                if self.writer:
                    global_step = epoch * len(self.train_loader) + batch_idx
                    self.writer.add_scalar("train/loss", loss.item(), global_step)
                    self.writer.add_scalar(
                        "train/accuracy", metrics["accuracy"], global_step
                    )

        if total_samples == 0:
            raise ValueError("train_loader yielded no samples")

        return {"loss": total_loss / total_samples}

    def validate(self) -> dict:
        """Evaluate the model on ``val_loader``.

        Metrics are computed once over all validation predictions rather than
        averaged per batch: a mean of per-batch precision/recall/F1 is not the
        dataset-level value and over-weights a short final batch.

        Returns:
            Dict with the sample-weighted ``loss`` plus every key produced by
            ``SimilarityLoss.compute_metrics``.

        Raises:
            ValueError: If the loader yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        total_samples = 0
        all_outputs = []
        all_targets = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                google_img = batch["google_img"].to(self.device)
                yandex_img = batch["yandex_img"].to(self.device)
                target = batch["same_domain"].float().to(self.device).unsqueeze(1)

                output = self.model(google_img, yandex_img)
                loss = self.criterion(output, target)

                batch_size = google_img.size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size

                # Keep only the small (batch, 1) score/target tensors, on CPU.
                all_outputs.append(output.cpu())
                all_targets.append(target.cpu())

        if total_samples == 0:
            raise ValueError("val_loader yielded no samples")

        metrics = self.criterion.compute_metrics(
            torch.cat(all_outputs), torch.cat(all_targets)
        )

        return {"loss": total_loss / total_samples, **metrics}

    def train(self, num_epochs: int):
        """Run the full training loop.

        Each epoch trains, validates, logs to TensorBoard and checkpoints. A
        new best validation loss always saves ``best_model.pt``; otherwise a
        checkpoint is written every ``save_interval`` epochs (config, default
        1, i.e. every epoch). Training stops early after
        ``early_stopping_patience`` epochs (config, default 20) without
        improvement. The writer is closed even if an epoch raises.
        """
        log_dir = self.config.get("output_dir", "runs/similarity")
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)

        patience = self.config.get("early_stopping_patience", 20)
        save_interval = max(1, int(self.config.get("save_interval", 1)))

        print(f"Starting training for {num_epochs} epochs")
        print(f"Logging to: {log_dir}")

        try:
            for epoch in range(1, num_epochs + 1):
                print(f"\nEpoch {epoch}/{num_epochs}")

                train_metrics = self.train_epoch(epoch)
                val_metrics = self.validate()

                print(f"Train Loss: {train_metrics['loss']:.4f}")
                print(f"Val Loss: {val_metrics['loss']:.4f}")
                print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
                print(f"Val F1: {val_metrics['f1']:.4f}")

                if self.writer:
                    self.writer.add_scalar(
                        "epoch/train_loss", train_metrics["loss"], epoch
                    )
                    self.writer.add_scalar("epoch/val_loss", val_metrics["loss"], epoch)
                    self.writer.add_scalar(
                        "epoch/val_accuracy", val_metrics["accuracy"], epoch
                    )

                if val_metrics["loss"] < self.best_val_loss:
                    self.best_val_loss = val_metrics["loss"]
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_metrics["loss"], is_best=True)
                    print(
                        f"New best model saved with val loss: {val_metrics['loss']:.4f}"
                    )
                else:
                    self.epochs_without_improvement += 1
                    if epoch % save_interval == 0:
                        self.save_checkpoint(epoch, val_metrics["loss"], is_best=False)

                if self.epochs_without_improvement >= patience:
                    print(
                        f"Early stopping triggered after {patience} epochs without improvement"
                    )
                    break
        finally:
            self.writer.close()

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write ``checkpoint_epoch_<epoch>.pt`` and, if best, ``best_model.pt``.

        Files go to ``<output_dir>/checkpoints`` and contain the epoch, model
        and optimizer state dicts, the validation loss and the config.
        """
        checkpoint_dir = os.path.join(
            self.config.get("output_dir", "runs/similarity"), "checkpoints"
        )
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": val_loss,
            "config": self.config,
        }

        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
        torch.save(checkpoint, checkpoint_path)

        if is_best:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model and optimizer state from a checkpoint file.

        Returns:
            Tuple ``(epoch, val_loss)`` stored in the checkpoint.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"], checkpoint["val_loss"]


def main():
    """Command-line entry point: build loaders and model, then train."""
    parser = argparse.ArgumentParser(description="Train similarity estimation model")
    parser.add_argument(
        "--data_dir",
        type=str,
        default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
    )
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--train_split", type=float, default=0.8)
    parser.add_argument("--output_dir", type=str, default="runs/similarity")
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )

    args = parser.parse_args()

    config = {
        "data_dir": args.data_dir,
        "batch_size": args.batch_size,
        "epochs": args.epochs,
        "learning_rate": args.learning_rate,
        "image_size": (args.image_size, args.image_size),
        "train_split": args.train_split,
        "output_dir": args.output_dir,
        "num_workers": args.num_workers,
        "log_interval": 10,
        "save_interval": 5,
        "early_stopping_patience": 20,
        "beta1": 0.5,
        "beta2": 0.999,
    }

    device = torch.device(args.device)
    print(f"Using device: {device}")

    print("Creating data loaders...")
    train_loader, val_loader = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=config["batch_size"],
        train_split=config["train_split"],
        num_workers=config["num_workers"],
        image_size=config["image_size"],
        augment_train=True,
        augment_val=False,
        device=device,
    )

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    print("Creating model...")
    model = create_similarity_model(
        model_type="cnn",
        input_size=config["image_size"],
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    )

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    trainer = SimilarityTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )

    print("Starting training...")
    trainer.train(config["epochs"])

    print("Training completed!")


if __name__ == "__main__":
    main()