from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class SimilarityCNN(nn.Module):
    """
    Model that scores how similar two images are, built on a pretrained backbone.

    The interface is compatible with the original model:
        - forward(img1, img2) -> tensor of shape (B, 1) with a score in [0, 1]
        - predict_similarity(img1, img2) -> tensor of shape (B, 1), no gradients

    Both images go through the same (shared-weight) backbone; the resulting
    feature vectors are combined and scored by a small MLP head.

    Note: with ``use_batch_norm=True`` the head contains ``nn.BatchNorm1d``
    layers, which reject a batch of size 1 while the model is in training mode.
    """

    def __init__(
        self,
        input_channels: int = 3,
        backbone_name: str = "resnet18",
        pretrained: bool = True,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        """
        Args:
            input_channels: Number of channels of the input images. If it is
                not 3, the backbone's first convolution is replaced (see
                ``_create_backbone``).
            backbone_name: "resnet18" or "resnet34" (case-insensitive).
            pretrained: Load torchvision IMAGENET1K_V1 weights for the backbone.
            dropout_rate: Dropout probability used in the comparison head.
            use_batch_norm: Add BatchNorm1d after each hidden Linear layer of the head.

        Raises:
            ValueError: If ``backbone_name`` is unsupported or ``input_channels`` < 1.
        """
        super().__init__()

        self.input_channels = input_channels
        self.backbone_name = backbone_name
        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        # 1. Build the backbone and keep everything up to (not including) the final FC.
        backbone = self._create_backbone(backbone_name, pretrained)

        # Width of the pooled feature vector (512 for ResNet18/34).
        self.feature_dim = backbone.fc.in_features

        # Replace the classification head with Identity so the backbone returns features only.
        backbone.fc = nn.Identity()
        self.backbone = backbone

        # 2. Head that compares the two feature vectors.
        # Input: [f1, f2, |f1 - f2|, f1 * f2] => 4 * feature_dim
        compare_input_dim = self.feature_dim * 4

        layers = [
            nn.Linear(compare_input_dim, 512),
            nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # squashes the score into [0, 1]
        ]
        self.head = nn.Sequential(*layers)

    def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:
        """
        Build a torchvision ResNet and adapt its stem to ``self.input_channels``.

        When ``input_channels != 3`` the first convolution (``conv1``) is
        replaced by one that accepts ``input_channels`` channels. If
        ``pretrained`` is True, the new filters are initialised from the
        pretrained RGB filters: they are averaged over the colour axis,
        repeated for each new channel and scaled by ``3 / input_channels``,
        which keeps the response to a channel-constant input unchanged.
        Without this, pretrained first-layer weights would be thrown away.
        Otherwise the new conv keeps PyTorch's default initialisation.

        Args:
            name: Backbone name, "resnet18" or "resnet34" (case-insensitive).
            pretrained: Whether to load IMAGENET1K_V1 weights.

        Returns:
            The ResNet module (its ``fc`` is still the original classifier).

        Raises:
            ValueError: If ``name`` is unsupported or ``input_channels`` < 1.
        """
        if self.input_channels < 1:
            raise ValueError(
                f"input_channels must be a positive integer, got {self.input_channels}"
            )

        name = name.lower()
        if name == "resnet18":
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            model = models.resnet18(weights=weights)
        elif name == "resnet34":
            weights = models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
            model = models.resnet34(weights=weights)
        else:
            raise ValueError(f"Unsupported backbone: {name}")

        if self.input_channels != 3:
            old_conv = model.conv1
            new_conv = nn.Conv2d(
                self.input_channels,
                old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=old_conv.bias is not None,
            )
            if pretrained:
                # Reuse the pretrained RGB filters instead of a random init.
                with torch.no_grad():
                    mean_filter = old_conv.weight.mean(dim=1, keepdim=True)
                    new_conv.weight.copy_(
                        mean_filter.repeat(1, self.input_channels, 1, 1)
                        * (3.0 / self.input_channels)
                    )
                    if old_conv.bias is not None:
                        new_conv.bias.copy_(old_conv.bias)
            model.conv1 = new_conv

        return model

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run one batch of images through the backbone.

        Because the backbone's ``fc`` has been replaced with ``nn.Identity``,
        this returns the feature vectors of shape (B, feature_dim).
        """
        return self.backbone(x)  # (B, feature_dim)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """
        Score the similarity of two image batches.

        Args:
            img1: Images of shape (B, C, H, W).
            img2: Images of shape (B, C, H, W).

        Returns:
            Similarity scores of shape (B, 1) in [0, 1].
        """
        f1 = self._extract_features(img1)  # (B, D)
        f2 = self._extract_features(img2)  # (B, D)

        # Comparison vector: both embeddings plus symmetric interaction terms.
        diff = torch.abs(f1 - f2)
        prod = f1 * f2
        combined = torch.cat([f1, f2, diff, prod], dim=1)  # (B, 4D)

        similarity = self.head(combined)  # (B, 1) in [0, 1]
        return similarity

    def predict_similarity(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """
        Inference without gradients; same interface as the original model.

        Temporarily switches to eval mode (so dropout is off and BatchNorm uses
        running stats). It restores train mode afterwards if the model was
        training, even when ``forward`` raises.

        Returns:
            Similarity scores of shape (B, 1) in [0, 1].
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(img1, img2)
        finally:
            if was_training:
                self.train()


class SimilarityLoss(nn.Module):
    """
    Binary cross-entropy loss on similarity scores, plus evaluation metrics.

    Keeps the same loss interface as the original code. Predictions are
    expected to be probabilities in [0, 1] (e.g. ``SimilarityCNN`` output);
    BCELoss is appropriate for binary (0/1) targets.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.BCELoss()

    @staticmethod
    def _align_target(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Cast ``target`` to ``pred``'s dtype and, if the element counts match,
        reshape it to ``pred``'s shape.

        This lets callers pass integer/bool labels or (B,) targets for (B, 1)
        predictions. Without it, ``compute_metrics`` would silently broadcast
        (B, 1) against (B,) into a (B, B) comparison. Targets with a different
        number of elements are left as-is, so invalid input fails as before.
        """
        target = target.to(dtype=pred.dtype)
        if target.shape != pred.shape and target.numel() == pred.numel():
            target = target.reshape(pred.shape)
        return target

    def forward(self, pred_similarity: torch.Tensor, target_same: torch.Tensor) -> torch.Tensor:
        """Return the mean BCE between predicted similarities and targets."""
        target_same = self._align_target(pred_similarity, target_same)
        return self.criterion(pred_similarity, target_same)

    def compute_metrics(
        self,
        pred_similarity: torch.Tensor,
        target_same: torch.Tensor,
        threshold: float = 0.5,
    ) -> dict:
        """
        Compute binary classification metrics without tracking gradients.

        Predictions are binarised with ``> threshold``; targets with ``> 0.5``.
        A 1e-8 epsilon keeps precision/recall/F1 finite (0.0) when a
        denominator would be zero.

        Returns:
            Dict with "accuracy", "precision", "recall", "f1" and
            "mean_similarity" (mean of the raw predictions), all Python floats.
        """
        with torch.no_grad():
            target_same = self._align_target(pred_similarity, target_same)
            pred_binary = (pred_similarity > threshold).float()
            target_binary = (target_same > 0.5).float()

            accuracy = (pred_binary == target_binary).float().mean().item()

            tp = ((pred_binary == 1) & (target_binary == 1)).float().sum().item()
            fp = ((pred_binary == 1) & (target_binary == 0)).float().sum().item()
            fn = ((pred_binary == 0) & (target_binary == 1)).float().sum().item()

            precision = tp / (tp + fp + 1e-8)
            recall = tp / (tp + fn + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)

            return {
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "mean_similarity": pred_similarity.mean().item(),
            }


def create_similarity_model(
    model_type: str = "backbone",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """
    Factory equivalent to the original one, with the new backbone-based model type.

    Args:
        model_type: Only "backbone" is supported.
        input_size: Accepted for interface compatibility; currently unused.
        **kwargs: Forwarded to ``SimilarityCNN``.

    Raises:
        ValueError: If ``model_type`` is unknown.
    """
    if model_type == "backbone":
        return SimilarityCNN(**kwargs)
    raise ValueError(f"Unknown model type: {model_type}")


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = SimilarityCNN(
        input_channels=3,
        backbone_name="resnet18",
        pretrained=True,
        dropout_rate=0.3,
        use_batch_norm=True,
    ).to(device)

    print(
        f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
    )

    batch_size = 4
    height, width = 256, 256
    img1 = torch.randn(batch_size, 3, height, width, device=device)
    img2 = torch.randn(batch_size, 3, height, width, device=device)

    print("\nTesting forward pass...")
    output = model(img1, img2)
    print(f"Output shape: {output.shape}")
    print(f"Sample output: {output[0].item():.4f}")

    print("\nTesting prediction...")
    pred = model.predict_similarity(img1, img2)
    print(f"Prediction shape: {pred.shape}")

    print("\nTesting loss function...")
    # Binary same/different labels, matching the intended BCELoss use case.
    target = torch.randint(0, 2, (batch_size, 1), device=device).float()
    loss_fn = SimilarityLoss().to(device)
    loss = loss_fn(output, target)
    print(f"Loss value: {loss.item():.6f}")

    print("\nTesting metrics...")
    metrics = loss_fn.compute_metrics(output, target)
    for key, value in metrics.items():
        print(f"{key}: {value:.6f}")

    print("\nAll tests completed successfully!")