323 lines
10 KiB
Python
323 lines
10 KiB
Python
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class SimilarityCNN(nn.Module):
    """
    CNN model for similarity estimation between two images.

    Takes two images as input and outputs a similarity score between 0 and 1.

    Both images are passed through the same encoder (shared weights). The two
    resulting feature maps are concatenated along the channel axis, fused by
    a small conv stack that ends in global average pooling, and mapped by a
    fully connected head with a final sigmoid to one score per image pair.

    Because the features are concatenated in ``(img1, img2)`` order, the
    score is not guaranteed to be symmetric in its two arguments.

    Args:
        input_channels: Number of channels of each input image.
        hidden_channels: Width of the encoder stem; later stages use
            multiples of it.
        num_blocks: Number of residual blocks in the encoder. The first two
            blocks each double the channel count; further blocks keep it.
            Any non-negative value is supported.
        dropout_rate: Dropout probability used in the residual blocks, the
            fusion layers and the regression head.
        use_batch_norm: Whether to insert batch-normalization layers.
    """

    def __init__(
        self,
        input_channels: int = 3,
        hidden_channels: int = 64,
        num_blocks: int = 4,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_blocks = num_blocks
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        # Build order matters: _build_encoder() records the encoder output
        # width in self.encoder_out_channels, which _build_fusion_layers() reads.
        self.encoder = self._build_encoder()
        self.fusion_layers = self._build_fusion_layers()
        self.regression_head = self._build_regression_head()

        self._initialize_weights()

    def _build_encoder(self) -> nn.Module:
        """Build the image encoder shared by both inputs.

        Stem: 7x7 stride-2 conv (+ BatchNorm2d), ReLU, 3x3 stride-2 max-pool.
        Then ``num_blocks`` residual blocks: block 0 uses stride 1, every
        later block uses stride 2; blocks 0 and 1 double the channel count.

        Side effect: stores the final channel count in
        ``self.encoder_out_channels``.
        """
        channels = self.hidden_channels

        layers = [
            nn.Conv2d(
                self.input_channels, channels, kernel_size=7, stride=2, padding=3
            )
        ]
        if self.use_batch_norm:
            layers.append(nn.BatchNorm2d(channels))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        for i in range(self.num_blocks):
            block_out_channels = channels * 2 if i < 2 else channels
            layers.append(
                ResidualBlock(
                    in_channels=channels,
                    out_channels=block_out_channels,
                    stride=1 if i == 0 else 2,
                    dropout_rate=self.dropout_rate,
                    use_batch_norm=self.use_batch_norm,
                )
            )
            channels = block_out_channels

        # Recorded so the fusion stack matches the real encoder width for any
        # num_blocks (hard-coding hidden_channels * 8 broke num_blocks < 2).
        self.encoder_out_channels = channels
        return nn.Sequential(*layers)

    def _build_fusion_layers(self) -> nn.Module:
        """Build the conv stack that fuses the concatenated pair features.

        Two 3x3 conv -> (BatchNorm2d) -> ReLU -> Dropout2d stages reduce the
        channels to ``hidden_channels * 4`` and then ``hidden_channels * 2``.
        An adaptive average pool then reduces the spatial size to 1x1.
        """
        # Features of both images are concatenated along the channel axis.
        fused_channels = 2 * self.encoder_out_channels

        layers = [
            nn.Conv2d(
                fused_channels, self.hidden_channels * 4, kernel_size=3, padding=1
            ),
            nn.BatchNorm2d(self.hidden_channels * 4)
            if self.use_batch_norm
            else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            nn.Conv2d(
                self.hidden_channels * 4,
                self.hidden_channels * 2,
                kernel_size=3,
                padding=1,
            ),
            nn.BatchNorm2d(self.hidden_channels * 2)
            if self.use_batch_norm
            else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]

        return nn.Sequential(*layers)

    def _build_regression_head(self) -> nn.Module:
        """Build the MLP that maps pooled fused features to a score.

        The layer widths are 2*hidden -> 512 -> 256 -> 128 -> 1. Each hidden
        layer is followed by (BatchNorm1d), ReLU and Dropout, and a final
        sigmoid squashes the score into [0, 1].
        """
        input_features = self.hidden_channels * 2

        layers = [
            nn.Flatten(),
            nn.Linear(input_features, 512),
            nn.BatchNorm1d(512) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        ]

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Initialize every Conv2d, BatchNorm and Linear layer in the model.

        The layers are initialized as follows:
        - Conv2d: Kaiming-normal weights (fan_out, ReLU gain) and zero bias.
        - BatchNorm: weight 1 and bias 0.
        - Linear: N(0, 0.01) weights and zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(
        self,
        img1: torch.Tensor,
        img2: torch.Tensor,
    ) -> torch.Tensor:
        """Return similarity scores for two batches of images.

        Args:
            img1: First image batch; expected as ``(N, input_channels, H, W)``
                (channels on dim 1, since features are concatenated there).
            img2: Second image batch. It should match ``img1`` in batch and
                spatial size so the encoder outputs can be concatenated.

        Returns:
            Tensor of shape ``(N, 1)`` with scores in [0, 1].
        """
        features1 = self.encoder(img1)
        features2 = self.encoder(img2)

        combined_features = torch.cat([features1, features2], dim=1)
        fused_features = self.fusion_layers(combined_features)

        return self.regression_head(fused_features)

    def predict_similarity(
        self,
        img1: torch.Tensor,
        img2: torch.Tensor,
    ) -> torch.Tensor:
        """Score image pairs in eval mode without tracking gradients.

        The train/eval flag of every submodule is saved beforehand. It is
        restored afterwards even if the forward pass raises, so calling this
        in the middle of training does not leave the model stuck in eval mode.
        """
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(img1, img2)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a residual (skip) connection.

    Main path: conv(stride) -> BN -> ReLU -> Dropout2d -> conv -> BN. The
    shortcut is added to it, followed by ReLU -> Dropout2d. When the stride
    or the channel count changes, the shortcut is a strided 1x1 conv (+ BN)
    so its output shape matches the main path. Otherwise it passes the input
    through unchanged.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        stride: Stride of the first conv and of the projection shortcut.
        dropout_rate: Dropout2d probability; a value of 0 or less disables it.
        use_batch_norm: Insert BatchNorm2d after each conv (``nn.Identity``
            otherwise). Without normalization the convolutions keep their
            bias terms.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        super().__init__()

        # A conv bias is redundant only when BatchNorm (whose beta supplies the
        # shift) follows it; without BN, dropping it leaves no offset at all.
        conv_bias = not use_batch_norm

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=conv_bias,
        )
        self.bn1 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout1 = (
            nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
        )

        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=conv_bias,
        )
        self.bn2 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout2 = (
            nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
        )

        if stride != 1 or in_channels != out_channels:
            # Projection shortcut so the skip tensor matches the main path shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=conv_bias,
                ),
                nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity(),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block. The spatial size is divided by ``stride``."""
        identity = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.dropout1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # `out` is a fresh tensor from bn2/conv2, so the in-place add does not
        # modify the block input even when the shortcut is the identity.
        out += identity
        out = self.relu2(out)
        out = self.dropout2(out)

        return out
|
|
|
|
|
|
class SimilarityLoss(nn.Module):
    """Binary cross-entropy loss for similarity scores, plus batch metrics.

    Uses ``nn.BCELoss``. ``pred_similarity`` must therefore already be a
    probability in [0, 1] (e.g. a sigmoid output), not a raw logit.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.BCELoss()

    def forward(
        self,
        pred_similarity: torch.Tensor,
        target_same: torch.Tensor,
    ) -> torch.Tensor:
        """Return the mean binary cross-entropy of predictions vs. targets.

        ``target_same`` may hold soft labels in [0, 1]. It is cast to the
        prediction dtype. If it has the same number of elements as
        ``pred_similarity`` but a different shape (e.g. ``(B,)`` vs
        ``(B, 1)``), it is reshaped to match.

        Raises:
            ValueError: from ``nn.BCELoss`` when the element counts differ.
        """
        target = target_same.to(dtype=pred_similarity.dtype)
        if (
            target.shape != pred_similarity.shape
            and target.numel() == pred_similarity.numel()
        ):
            target = target.reshape(pred_similarity.shape)
        return self.criterion(pred_similarity, target)

    def compute_metrics(
        self,
        pred_similarity: torch.Tensor,
        target_same: torch.Tensor,
        threshold: float = 0.5,
    ) -> dict:
        """Compute thresholded classification metrics for a batch.

        Predictions strictly greater than ``threshold`` count as positive.
        Targets strictly greater than 0.5 count as positive, so soft labels
        are binarized. Both tensors are flattened first, so ``(B, 1)``
        predictions can be compared with ``(B,)`` targets without
        broadcasting into a ``(B, B)`` grid.

        Returns:
            dict with ``accuracy``, ``precision``, ``recall``, ``f1`` and
            ``mean_similarity`` as Python floats. A 1e-8 epsilon keeps
            precision/recall/F1 finite when a denominator is zero.

        Raises:
            ValueError: if the two tensors hold different numbers of elements.
        """
        with torch.no_grad():
            pred_flat = pred_similarity.reshape(-1)
            target_flat = target_same.reshape(-1)
            if pred_flat.numel() != target_flat.numel():
                raise ValueError(
                    f"pred_similarity has {pred_flat.numel()} elements but "
                    f"target_same has {target_flat.numel()}"
                )

            pred_pos = pred_flat > threshold
            target_pos = target_flat > 0.5

            tp = (pred_pos & target_pos).sum().item()
            fp = (pred_pos & ~target_pos).sum().item()
            fn = (~pred_pos & target_pos).sum().item()

            accuracy = (pred_pos == target_pos).float().mean().item()
            precision = tp / (tp + fp + 1e-8)
            recall = tp / (tp + fn + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)

            return {
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "mean_similarity": pred_similarity.mean().item(),
            }
|
|
|
|
|
|
def create_similarity_model(
    model_type: str = "cnn",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """Construct a similarity model by name.

    Args:
        model_type: Architecture identifier. Only ``"cnn"`` is recognized.
        input_size: Accepted for API compatibility; this function does not
            currently use it.
        **kwargs: Forwarded unchanged to the model constructor.

    Returns:
        A freshly constructed ``SimilarityCNN``.

    Raises:
        ValueError: if ``model_type`` is not recognized.
    """
    if model_type != "cnn":
        raise ValueError(f"Unknown model type: {model_type}")
    return SimilarityCNN(**kwargs)
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a model, push random data through it, and exercise
    # predict_similarity, the loss and the metrics, printing each result.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- model -----------------------------------------------------------
    model = SimilarityCNN(
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    )
    model = model.to(device)
    print(
        f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
    )

    # --- random inputs (img1 is drawn before img2) -------------------------
    batch_size = 4
    height = width = 256
    img1, img2 = (
        torch.randn(batch_size, 3, height, width).to(device) for _ in range(2)
    )

    # --- forward pass (training mode, gradients tracked) -------------------
    print("\nTesting forward pass...")
    output = model(img1, img2)
    print(f"Output shape: {output.shape}")
    print(f"Sample output: {output[0].item():.4f}")

    # --- inference helper --------------------------------------------------
    print("\nTesting prediction...")
    pred = model.predict_similarity(img1, img2)
    print(f"Prediction shape: {pred.shape}")

    # --- loss against soft random targets ----------------------------------
    print("\nTesting loss function...")
    soft_targets = torch.rand(batch_size, 1).to(device)
    criterion = SimilarityLoss().to(device)
    loss = criterion(output, soft_targets)
    print(f"Loss value: {loss.item():.6f}")

    # --- metrics -----------------------------------------------------------
    print("\nTesting metrics...")
    metric_values = criterion.compute_metrics(output, soft_targets)
    for key, value in metric_values.items():
        print(f"{key}: {value:.6f}")

    print("\nAll tests completed successfully!")
|