from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimilarityCNN(nn.Module):
    """
    CNN model for similarity estimation between two images.

    Both images are passed through the same (weight-shared) encoder, the two
    feature maps are concatenated along the channel axis, fused by a small
    convolutional stack and reduced to a single score by an MLP ending in a
    sigmoid, so the output lies in [0, 1].

    Args:
        input_channels: Number of channels of each input image.
        hidden_channels: Width of the encoder stem; later stages are
            multiples of this value.
        num_blocks: Number of residual blocks in the encoder. Only the first
            two blocks double the channel count.
        dropout_rate: Dropout probability used in the residual blocks, the
            fusion layers and the regression head.
        use_batch_norm: Whether to insert batch normalisation layers.
    """

    # Only the first this-many residual blocks double the channel count.
    _NUM_WIDENING_BLOCKS = 2

    def __init__(
        self,
        input_channels: int = 3,
        hidden_channels: int = 64,
        num_blocks: int = 4,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_blocks = num_blocks
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        self.encoder = self._build_encoder()
        self.fusion_layers = self._build_fusion_layers()
        self.regression_head = self._build_regression_head()
        self._initialize_weights()

    @property
    def encoder_out_channels(self) -> int:
        """Channel count of the feature map produced by ``self.encoder``."""
        # A non-positive num_blocks yields no residual blocks at all
        # (range() is empty), so clamp at zero doublings.
        doublings = min(max(self.num_blocks, 0), self._NUM_WIDENING_BLOCKS)
        return self.hidden_channels * (2 ** doublings)

    def _norm2d(self, channels: int) -> nn.Module:
        """Return a BatchNorm2d layer, or Identity when batch norm is off."""
        if self.use_batch_norm:
            return nn.BatchNorm2d(channels)
        return nn.Identity()

    def _norm1d(self, features: int) -> nn.Module:
        """Return a BatchNorm1d layer, or Identity when batch norm is off."""
        if self.use_batch_norm:
            return nn.BatchNorm1d(features)
        return nn.Identity()

    def _build_encoder(self) -> nn.Module:
        """Build the shared image encoder: a conv stem plus residual blocks."""
        channels = self.hidden_channels

        layers = [
            nn.Conv2d(
                self.input_channels, channels, kernel_size=7, stride=2, padding=3
            )
        ]
        # No Identity placeholder here: the stem simply omits the norm layer
        # when batch norm is disabled.
        if self.use_batch_norm:
            layers.append(nn.BatchNorm2d(channels))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        for index in range(self.num_blocks):
            widen = index < self._NUM_WIDENING_BLOCKS
            block_out_channels = channels * 2 if widen else channels
            layers.append(
                ResidualBlock(
                    in_channels=channels,
                    out_channels=block_out_channels,
                    # The first block keeps resolution; later ones halve it.
                    stride=1 if index == 0 else 2,
                    dropout_rate=self.dropout_rate,
                    use_batch_norm=self.use_batch_norm,
                )
            )
            channels = block_out_channels

        return nn.Sequential(*layers)

    def _build_fusion_layers(self) -> nn.Module:
        """Build the conv stack that fuses the two concatenated feature maps."""
        # forward() concatenates two encoder outputs along the channel axis.
        # Deriving the width from the encoder (rather than hard-coding
        # hidden_channels * 8) keeps the model usable when num_blocks < 2.
        fused_channels = 2 * self.encoder_out_channels
        mid_channels = self.hidden_channels * 4
        out_channels = self.hidden_channels * 2

        return nn.Sequential(
            nn.Conv2d(fused_channels, mid_channels, kernel_size=3, padding=1),
            self._norm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            self._norm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

    def _build_regression_head(self) -> nn.Module:
        """Build the MLP mapping pooled fused features to a score in [0, 1]."""
        input_features = self.hidden_channels * 2

        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_features, 512),
            self._norm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(512, 256),
            self._norm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(256, 128),
            self._norm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def _initialize_weights(self) -> None:
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                # Convs inside ResidualBlock are created with bias=False.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        img1: torch.Tensor,
        img2: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the similarity score for a batch of image pairs.

        Args:
            img1: First image batch; must have ``input_channels`` channels.
            img2: Second image batch; its encoded feature map must match
                ``img1``'s in every dimension except channels so the two can
                be concatenated.

        Returns:
            Tensor with one sigmoid score per pair (last dimension is 1).
        """
        features1 = self.encoder(img1)
        features2 = self.encoder(img2)
        combined_features = torch.cat([features1, features2], dim=1)
        fused_features = self.fusion_layers(combined_features)
        return self.regression_head(fused_features)

    def predict_similarity(
        self,
        img1: torch.Tensor,
        img2: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run ``forward`` in eval mode without gradient tracking.

        The module's previous training flag is restored afterwards, even if
        the forward pass raises.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(img1, img2)
        finally:
            if was_training:
                self.train()


class ResidualBlock(nn.Module):
    """
    Two 3x3 convolutions with a skip connection.

    A 1x1 projection is used on the skip path whenever the stride or the
    channel count changes; otherwise the input is added unchanged.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        stride: Stride of the first convolution (and of the projection).
        dropout_rate: Dropout2d probability; dropout is skipped when <= 0.
        use_batch_norm: Whether to use BatchNorm2d after each convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        super().__init__()

        def make_norm() -> nn.Module:
            return nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()

        def make_dropout() -> nn.Module:
            return nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = make_norm()
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout1 = make_dropout()

        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = make_norm()
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout2 = make_dropout()

        # An empty Sequential acts as the identity on the skip path.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                make_norm(),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv-norm-relu-dropout, conv-norm, add the skip, then relu-dropout."""
        identity = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.dropout1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu2(out)
        out = self.dropout2(out)
        return out


class SimilarityLoss(nn.Module):
    """Binary cross-entropy loss on similarity scores, plus simple metrics."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.BCELoss()

    @staticmethod
    def _align_target(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Reshape ``target`` to ``pred``'s shape when element counts match.

        This avoids silent broadcasting such as (B, 1) against (B,), which
        would otherwise yield a (B, B) comparison. Targets with a different
        number of elements are returned unchanged.
        """
        if target.shape != pred.shape and target.numel() == pred.numel():
            return target.reshape(pred.shape)
        return target

    def forward(
        self,
        pred_similarity: torch.Tensor,
        target_same: torch.Tensor,
    ) -> torch.Tensor:
        """
        Return the mean binary cross-entropy between predictions and targets.

        Args:
            pred_similarity: Predicted probabilities in [0, 1].
            target_same: Target values in [0, 1]; may be an integer tensor
                and may differ in shape from the predictions as long as the
                element count matches.
        """
        target = self._align_target(pred_similarity, target_same)
        # BCELoss needs the target in the same floating dtype as the input.
        target = target.to(dtype=pred_similarity.dtype)
        return self.criterion(pred_similarity, target)

    def compute_metrics(
        self,
        pred_similarity: torch.Tensor,
        target_same: torch.Tensor,
        threshold: float = 0.5,
    ) -> Dict[str, float]:
        """
        Compute classification metrics by thresholding the predictions.

        Args:
            pred_similarity: Predicted similarity scores.
            target_same: Targets; values above 0.5 count as positive.
            threshold: Predictions strictly above this count as positive.

        Returns:
            Dict with ``accuracy``, ``precision``, ``recall``, ``f1`` and
            ``mean_similarity`` as Python floats.
        """
        eps = 1e-8  # guards against division by zero when a count is empty

        with torch.no_grad():
            target = self._align_target(pred_similarity, target_same)
            pred_pos = pred_similarity > threshold
            target_pos = target > 0.5

            accuracy = (pred_pos == target_pos).float().mean().item()

            # tn is not needed for any metric reported below.
            tp = (pred_pos & target_pos).float().sum().item()
            fp = (pred_pos & ~target_pos).float().sum().item()
            fn = (~pred_pos & target_pos).float().sum().item()

            precision = tp / (tp + fp + eps)
            recall = tp / (tp + fn + eps)
            f1 = 2 * precision * recall / (precision + recall + eps)

            return {
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "mean_similarity": pred_similarity.mean().item(),
            }


def create_similarity_model(
    model_type: str = "cnn",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """
    Create a similarity model by type name.

    Args:
        model_type: Model family; only ``"cnn"`` is supported.
        input_size: Accepted for interface compatibility; currently unused.
        **kwargs: Forwarded to the model constructor.

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """
    if model_type == "cnn":
        return SimilarityCNN(**kwargs)
    raise ValueError(f"Unknown model type: {model_type}")


def main() -> None:
    """Smoke-test the model, loss and metrics on random data."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = SimilarityCNN(
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    ).to(device)
    print(
        f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
    )

    batch_size = 4
    height, width = 256, 256
    img1 = torch.randn(batch_size, 3, height, width).to(device)
    img2 = torch.randn(batch_size, 3, height, width).to(device)

    print("\nTesting forward pass...")
    output = model(img1, img2)
    print(f"Output shape: {output.shape}")
    print(f"Sample output: {output[0].item():.4f}")

    print("\nTesting prediction...")
    pred = model.predict_similarity(img1, img2)
    print(f"Prediction shape: {pred.shape}")

    print("\nTesting loss function...")
    target = torch.rand(batch_size, 1).to(device)
    loss_fn = SimilarityLoss().to(device)
    loss = loss_fn(output, target)
    print(f"Loss value: {loss.item():.6f}")

    print("\nTesting metrics...")
    metrics = loss_fn.compute_metrics(output, target)
    for key, value in metrics.items():
        print(f"{key}: {value:.6f}")

    print("\nAll tests completed successfully!")


if __name__ == "__main__":
    main()