220 lines
7.9 KiB
Python
220 lines
7.9 KiB
Python
from typing import Tuple
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from torchvision import models
|
||
|
||
|
||
class SimilarityCNN(nn.Module):
    """Score the similarity of two images using a shared ResNet backbone.

    Both images are passed through the same torchvision ResNet whose ``fc``
    layer has been replaced by ``nn.Identity``. The two feature vectors are
    combined as ``[f1, f2, |f1 - f2|, f1 * f2]`` and fed to a small MLP head
    that ends in a sigmoid.

    The interface is compatible with the original model:
        - ``forward(img1, img2)`` -> tensor ``(B, 1)`` with a score in [0, 1]
        - ``predict_similarity(img1, img2)`` -> tensor ``(B, 1)`` computed
          without gradients
    """

    def __init__(
        self,
        input_channels: int = 3,
        backbone_name: str = "resnet18",
        pretrained: bool = True,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ) -> None:
        """Build the backbone and the comparison head.

        Args:
            input_channels: Number of channels of the input images. When it
                is not 3, the first convolution of the backbone is replaced.
            backbone_name: ``"resnet18"`` or ``"resnet34"`` (case-insensitive).
            pretrained: Load the ``IMAGENET1K_V1`` weights when true.
            dropout_rate: Dropout probability used in the head.
            use_batch_norm: Insert ``BatchNorm1d`` after the hidden linear
                layers; otherwise ``nn.Identity`` placeholders are used so the
                layer indices inside ``self.head`` stay the same.

        Raises:
            ValueError: If ``backbone_name`` is not supported.
        """
        super().__init__()

        self.input_channels = input_channels
        self.backbone_name = backbone_name
        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        # 1. Create the backbone and keep everything up to the last FC layer.
        backbone = self._create_backbone(backbone_name, pretrained)

        # Feature width is whatever the classifier expected as input.
        self.feature_dim = backbone.fc.in_features
        # Replace the classification head with Identity so that the backbone
        # returns the feature vector itself.
        backbone.fc = nn.Identity()
        self.backbone = backbone

        # 2. Head that compares the two feature vectors.
        # Input: [f1, f2, |f1 - f2|, f1 * f2] => 4 * feature_dim
        compare_input_dim = self.feature_dim * 4

        layers = [
            nn.Linear(compare_input_dim, 512),
            nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # output in [0, 1]
        ]
        self.head = nn.Sequential(*layers)

    def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:
        """Instantiate the requested torchvision ResNet.

        Reads ``self.input_channels``, so it must be called after that
        attribute has been assigned.

        Args:
            name: Backbone name; compared case-insensitively.
            pretrained: Load the ``IMAGENET1K_V1`` weights when true.

        Returns:
            The ResNet module, with ``conv1`` replaced when
            ``self.input_channels`` is not 3.

        Raises:
            ValueError: If ``name`` is not ``"resnet18"`` or ``"resnet34"``.
        """
        name = name.lower()
        if name == "resnet18":
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            model = models.resnet18(weights=weights)
        elif name == "resnet34":
            weights = models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
            model = models.resnet34(weights=weights)
        else:
            raise ValueError(f"Unsupported backbone: {name}")

        if self.input_channels != 3:
            # Swap the first convolution for one with the same geometry but
            # the requested number of input channels. The new layer is
            # freshly initialised, so any pretrained conv1 weights are not
            # reused.
            old_conv = model.conv1
            model.conv1 = nn.Conv2d(
                self.input_channels,
                old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=old_conv.bias is not None,
            )
        return model

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run one image batch through the backbone.

        Because ``fc`` is ``nn.Identity``, this is the backbone's plain
        forward pass and yields the feature vectors, ``(B, feature_dim)``.
        """
        return self.backbone(x)  # (B, feature_dim)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """Compute the similarity score for each pair of images.

        Args:
            img1: First image batch, ``(B, C, H, W)``.
            img2: Second image batch, ``(B, C, H, W)``.

        Returns:
            Similarity tensor ``(B, 1)`` with values in [0, 1].
        """
        f1 = self._extract_features(img1)  # (B, D)
        f2 = self._extract_features(img2)  # (B, D)

        # Comparison vector.
        diff = torch.abs(f1 - f2)
        prod = f1 * f2
        combined = torch.cat([f1, f2, diff, prod], dim=1)  # (B, 4D)

        return self.head(combined)  # (B, 1) in [0, 1]

    def predict_similarity(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """Run inference in eval mode without gradients.

        The module's previous training flag is restored afterwards, including
        when the forward pass raises.

        Args:
            img1: First image batch, ``(B, C, H, W)``.
            img2: Second image batch, ``(B, C, H, W)``.

        Returns:
            Similarity tensor ``(B, 1)`` detached from the autograd graph.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Call the module rather than ``forward`` so that registered
                # hooks run as they do for a normal forward pass.
                return self(img1, img2)
        finally:
            # Put the module back in the caller's mode even on failure.
            self.train(was_training)
|
||
|
||
|
||
class SimilarityLoss(nn.Module):
    """Binary cross-entropy loss plus classification metrics for similarity scores.

    Keeps the same loss interface as the original code. ``nn.BCELoss`` is
    used, which suits binary (0/1) targets and expects predictions that are
    already probabilities.
    """

    def __init__(self) -> None:
        super().__init__()
        self.criterion = nn.BCELoss()

    def forward(self, pred_similarity: torch.Tensor, target_same: torch.Tensor) -> torch.Tensor:
        """Return the BCE loss between predictions and targets.

        Args:
            pred_similarity: Predicted scores in [0, 1].
            target_same: Targets with the same shape as ``pred_similarity``.
                Integer or boolean labels are accepted.

        Returns:
            The loss produced by ``nn.BCELoss``.
        """
        # BCELoss needs the target in the prediction's floating dtype; the
        # cast is a no-op when the dtypes already match.
        target = target_same.to(dtype=pred_similarity.dtype)
        return self.criterion(pred_similarity, target)

    def compute_metrics(
        self,
        pred_similarity: torch.Tensor,
        target_same: torch.Tensor,
        threshold: float = 0.5,
    ) -> dict:
        """Compute accuracy, precision, recall, F1 and the mean prediction.

        Args:
            pred_similarity: Predicted scores.
            target_same: Targets; values above 0.5 count as positive. Must
                hold as many elements as ``pred_similarity``.
            threshold: Predictions strictly above this value count as
                positive.

        Returns:
            Dict with float entries ``accuracy``, ``precision``, ``recall``,
            ``f1`` and ``mean_similarity``.

        Raises:
            ValueError: If the two tensors hold a different number of
                elements.
        """
        if pred_similarity.numel() != target_same.numel():
            raise ValueError(
                "pred_similarity and target_same must have the same number of "
                f"elements, got {tuple(pred_similarity.shape)} and "
                f"{tuple(target_same.shape)}"
            )

        with torch.no_grad():
            # Flatten both tensors so that e.g. (B, 1) predictions and (B,)
            # targets are compared element by element rather than being
            # broadcast to a (B, B) matrix.
            pred_positive = pred_similarity.reshape(-1) > threshold
            target_positive = target_same.reshape(-1) > 0.5

            accuracy = (pred_positive == target_positive).float().mean().item()

            tp = (pred_positive & target_positive).sum().item()
            fp = (pred_positive & ~target_positive).sum().item()
            fn = (~pred_positive & target_positive).sum().item()

            # The small epsilon keeps the ratios defined when a denominator
            # would otherwise be zero.
            precision = tp / (tp + fp + 1e-8)
            recall = tp / (tp + fn + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)

            return {
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "mean_similarity": pred_similarity.mean().item(),
            }
|
||
|
||
|
||
def create_similarity_model(
    model_type: str = "backbone",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """Factory for similarity models, mirroring the original factory.

    Args:
        model_type: Kind of model to build; only ``"backbone"`` is known.
        input_size: Accepted for interface compatibility; not used here.
        **kwargs: Forwarded unchanged to ``SimilarityCNN``.

    Returns:
        A new ``SimilarityCNN`` instance.

    Raises:
        ValueError: If ``model_type`` is not ``"backbone"``.
    """
    # Reject unknown types up front; the single supported type follows.
    if model_type != "backbone":
        raise ValueError(f"Unknown model type: {model_type}")
    return SimilarityCNN(**kwargs)
|
||
|
||
|
||
if __name__ == "__main__":
    # Smoke test: build the model, run one random batch through the forward
    # pass, the no-grad prediction path, the loss and the metrics.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = SimilarityCNN(
        input_channels=3,
        backbone_name="resnet18",
        pretrained=True,
        dropout_rate=0.3,
        use_batch_norm=True,
    )
    model = model.to(device)

    print(
        f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
    )

    batch_size = 4
    height, width = 256, 256
    image_shape = (batch_size, 3, height, width)

    # Two independent random image batches, drawn in this order.
    img1 = torch.randn(*image_shape).to(device)
    img2 = torch.randn(*image_shape).to(device)

    print("\nTesting forward pass...")
    output = model(img1, img2)
    print(f"Output shape: {output.shape}")
    print(f"Sample output: {output[0].item():.4f}")

    print("\nTesting prediction...")
    pred = model.predict_similarity(img1, img2)
    print(f"Prediction shape: {pred.shape}")

    print("\nTesting loss function...")
    # Random soft targets in [0, 1), one per pair.
    target = torch.rand(batch_size, 1).to(device)
    loss_fn = SimilarityLoss().to(device)
    loss = loss_fn(output, target)
    print(f"Loss value: {loss.item():.6f}")

    print("\nTesting metrics...")
    metrics = loss_fn.compute_metrics(output, target)
    for key in metrics:
        value = metrics[key]
        print(f"{key}: {value:.6f}")

    print("\nAll tests completed successfully!")
|