Files

220 lines
7.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class SimilarityCNN(nn.Module):
    """
    Pairwise image-similarity model built on a torchvision ResNet backbone.

    Both images pass through the same (shared-weight) backbone. The two
    feature vectors are combined as ``[f1, f2, |f1 - f2|, f1 * f2]`` and
    scored by an MLP head that ends in a sigmoid.

    Interface (compatible with the original model):
      - forward(img1, img2) -> tensor (B, 1) with a score in [0, 1]
      - predict_similarity(img1, img2) -> tensor (B, 1), computed without gradients
    """

    def __init__(
        self,
        input_channels: int = 3,
        backbone_name: str = "resnet18",
        pretrained: bool = True,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.backbone_name = backbone_name
        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        # 1. Build the backbone and keep everything up to (not including) the final FC.
        backbone = self._create_backbone(backbone_name, pretrained)
        # Width of the pooled feature vector that fed the original FC layer.
        self.feature_dim = backbone.fc.in_features
        # Replace the classification head so the backbone returns features only.
        backbone.fc = nn.Identity()
        self.backbone = backbone

        # 2. Head that compares the two feature vectors.
        # Input: [f1, f2, |f1 - f2|, f1 * f2] => 4 * feature_dim
        compare_input_dim = self.feature_dim * 4
        layers = [
            nn.Linear(compare_input_dim, 512),
            nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # output in [0, 1]
        ]
        self.head = nn.Sequential(*layers)

    def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:
        """
        Build a torchvision ResNet by (case-insensitive) name.

        Args:
            name: one of "resnet18", "resnet34", "resnet50".
            pretrained: load ImageNet (IMAGENET1K_V1) weights when True.

        Returns:
            The ResNet module. If ``self.input_channels`` is not 3, its ``conv1``
            is replaced by a conv accepting that many channels (see
            ``_adapt_input_conv``).

        Raises:
            ValueError: if ``name`` is not a supported backbone.
        """
        name = name.lower()
        # name -> (constructor, ImageNet weights enum)
        factories = {
            "resnet18": (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1),
            "resnet34": (models.resnet34, models.ResNet34_Weights.IMAGENET1K_V1),
            "resnet50": (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V1),
        }
        if name not in factories:
            raise ValueError(f"Unsupported backbone: {name}")
        constructor, weights = factories[name]
        model = constructor(weights=weights if pretrained else None)

        if self.input_channels != 3:
            # Replace the stem conv rather than prepending an adapter; keep the
            # pretrained filters when there are any, instead of re-initialising.
            model.conv1 = self._adapt_input_conv(
                model.conv1, self.input_channels, copy_weights=pretrained
            )
        return model

    @staticmethod
    def _adapt_input_conv(
        old_conv: nn.Conv2d, in_channels: int, copy_weights: bool = True
    ) -> nn.Conv2d:
        """
        Return a conv like ``old_conv`` that takes ``in_channels`` input channels.

        Out-channels, kernel size, stride, padding, dilation, groups and the
        presence of a bias are copied from ``old_conv``. When ``copy_weights``
        is True, each new input channel receives the mean of the old per-channel
        filters, scaled by ``old_in / new_in``. This keeps the per-filter weight
        sum, and hence the response to a constant input, unchanged. When False
        the new conv keeps PyTorch's default random initialisation.
        """
        new_conv = nn.Conv2d(
            in_channels,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            dilation=old_conv.dilation,
            groups=old_conv.groups,
            bias=old_conv.bias is not None,
        )
        if copy_weights:
            with torch.no_grad():
                # weight layout: (out, in / groups, kh, kw)
                mean_filter = old_conv.weight.mean(dim=1, keepdim=True)
                scale = old_conv.weight.shape[1] / new_conv.weight.shape[1]
                new_conv.weight.copy_(mean_filter.expand_as(new_conv.weight) * scale)
                if old_conv.bias is not None:
                    new_conv.bias.copy_(old_conv.bias)
        return new_conv

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run one batch of images through the backbone and return its features.

        Because ``backbone.fc`` was replaced with ``nn.Identity`` in
        ``__init__``, this is just the backbone's forward pass.
        """
        return self.backbone(x)  # (B, feature_dim)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """
        Score the similarity of two image batches.

        Args:
            img1, img2: image batches of equal batch size (B, C, H, W).

        Returns:
            Tensor of shape (B, 1) with scores in [0, 1].
        """
        f1 = self._extract_features(img1)  # (B, D)
        f2 = self._extract_features(img2)  # (B, D)
        # Comparison vector: raw features plus symmetric interaction terms.
        diff = torch.abs(f1 - f2)
        prod = f1 * f2
        combined = torch.cat([f1, f2, diff, prod], dim=1)  # (B, 4D)
        similarity = self.head(combined)  # (B, 1) in [0, 1]
        return similarity

    def predict_similarity(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """
        Inference helper: run ``forward`` in eval mode without gradients.

        Training mode is restored afterwards, even if ``forward`` raises.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(img1, img2)
        finally:
            if was_training:
                self.train()
class SimilarityLoss(nn.Module):
    """
    Binary cross-entropy loss for similarity scores, plus simple metrics.

    ``pred_similarity`` must already be probabilities in [0, 1] (for example
    the sigmoid output of ``SimilarityCNN``), because ``nn.BCELoss`` is
    applied directly. For binary (0/1) targets BCE is the natural choice.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.BCELoss()

    def forward(self, pred_similarity: torch.Tensor, target_same: torch.Tensor) -> torch.Tensor:
        """Return the mean binary cross-entropy between scores and targets."""
        return self.criterion(pred_similarity, target_same)

    def compute_metrics(
        self,
        pred_similarity: torch.Tensor,
        target_same: torch.Tensor,
        threshold: float = 0.5,
    ) -> dict:
        """
        Compute classification metrics for thresholded similarity scores.

        A prediction counts as positive when ``> threshold``; a target counts
        as positive when ``> 0.5``. Both tensors are flattened first. This way
        a ``(B, 1)`` prediction and a ``(B,)`` target are compared
        element-wise instead of being broadcast to ``(B, B)``.

        Returns:
            dict with "accuracy", "precision", "recall", "f1" and
            "mean_similarity" (mean raw score), all Python floats.

        Raises:
            ValueError: if the two tensors hold different numbers of elements.
        """
        with torch.no_grad():
            pred_flat = pred_similarity.reshape(-1)
            target_flat = target_same.reshape(-1)
            if pred_flat.numel() != target_flat.numel():
                raise ValueError(
                    "pred_similarity and target_same must have the same number of "
                    f"elements, got {pred_flat.numel()} and {target_flat.numel()}"
                )

            pred_pos = pred_flat > threshold
            target_pos = target_flat > 0.5

            accuracy = (pred_pos == target_pos).float().mean().item()
            tp = (pred_pos & target_pos).sum().item()
            fp = (pred_pos & ~target_pos).sum().item()
            fn = (~pred_pos & target_pos).sum().item()

            # Small epsilon avoids division by zero when a class is absent.
            eps = 1e-8
            precision = tp / (tp + fp + eps)
            recall = tp / (tp + fn + eps)
            f1 = 2 * precision * recall / (precision + recall + eps)

            return {
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "mean_similarity": pred_similarity.mean().item(),
            }
def create_similarity_model(
    model_type: str = "backbone",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """
    Factory for similarity models, mirroring the original factory's signature.

    Args:
        model_type: only "backbone" is supported (builds ``SimilarityCNN``).
        input_size: accepted for signature compatibility; not used here.
        **kwargs: forwarded unchanged to ``SimilarityCNN``.

    Raises:
        ValueError: for any other ``model_type``.
    """
    if model_type != "backbone":
        raise ValueError(f"Unknown model type: {model_type}")
    return SimilarityCNN(**kwargs)
if __name__ == "__main__":
    # Smoke test: build the model, push random images through it, and
    # exercise the loss and metric helpers.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model_config = dict(
        input_channels=3,
        backbone_name="resnet18",
        pretrained=True,
        dropout_rate=0.3,
        use_batch_norm=True,
    )
    model = SimilarityCNN(**model_config).to(device)
    n_params = sum(param.numel() for param in model.parameters())
    print(f"Model created with {n_params:,} parameters")

    batch_size = 4
    height, width = 256, 256
    image_shape = (batch_size, 3, height, width)
    img1 = torch.randn(*image_shape).to(device)
    img2 = torch.randn(*image_shape).to(device)

    print("\nTesting forward pass...")
    output = model(img1, img2)
    print(f"Output shape: {output.shape}")
    print(f"Sample output: {output[0].item():.4f}")

    print("\nTesting prediction...")
    pred = model.predict_similarity(img1, img2)
    print(f"Prediction shape: {pred.shape}")

    print("\nTesting loss function...")
    target = torch.rand(batch_size, 1).to(device)
    loss_fn = SimilarityLoss().to(device)
    loss = loss_fn(output, target)
    print(f"Loss value: {loss.item():.6f}")

    print("\nTesting metrics...")
    metrics = loss_fn.compute_metrics(output, target)
    for metric_name, metric_value in metrics.items():
        print(f"{metric_name}: {metric_value:.6f}")

    print("\nAll tests completed successfully!")