try custom cnn

This commit is contained in:
2026-04-14 12:37:09 +03:00
parent 2ec0763e6d
commit fc072d798e
4 changed files with 1349 additions and 888 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,7 @@ import os
import torch
from .dataloader import create_data_loaders, get_dataset_info
from .model import HomographyCNN6, count_parameters
from .model import HomographyCNN6, HomographyHybridCNN, HomographyLoss, count_parameters
from .train import HomographyTrainer
from .analyze import analyze_training
from .utils import config
@@ -29,18 +29,24 @@ train_loader, val_loader = create_data_loaders(
)
logger.info(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
model = HomographyCNN6(
# model = HomographyCNN6(
# input_channels=3,
# backbone_name=config["backbone"],
# pretrained=True,
# dropout_rate=config["dropout_rate"]
# )
model = HomographyHybridCNN(
input_channels=3,
backbone_name=config["backbone"],
pretrained=True,
dropout_rate=config["dropout_rate"]
dropout_rate=config["dropout_rate"],  # same config key the previous HomographyCNN6 call read
)
logger.info(f"Model created with {count_parameters(model):,} parameters")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
trainer = HomographyTrainer(model, train_loader, val_loader, device)
trainer = HomographyTrainer(model, train_loader, val_loader, device, HomographyLoss())
logger.info("Starting training...")
trainer.train(config["epochs"])
logger.info("Training completed")

View File

@@ -77,6 +77,126 @@ class HomographyCNN6(nn.Module):
}
class HomographyHybridCNN(nn.Module):
    """Two-image regressor built from a ResNet stem plus a custom conv/MLP head.

    Both images are passed through the same feature extractor: the first one
    or two stages of a torchvision ResNet, followed by ``conv_head`` and
    global average pooling, giving one 256-element descriptor per image.
    The descriptors are combined as ``[x1, x2, |x1 - x2|, x1 * x2]`` and
    regressed by ``head`` to 6 values per sample.

    Args:
        input_channels: Number of channels of each input image. When it
            differs from the backbone's stem convolution, the stem is
            replaced by a freshly initialised convolution with the same
            geometry (the pretrained stem weights are then not used).
        use_resnet_layers: ``1`` keeps only the ResNet stem
            (conv1/bn1/relu/maxpool). ``2`` additionally applies the two
            convolutions of ``layer1[0]`` (without that block's residual
            connection) followed by one more max-pool.
        dropout_rate: Dropout probability used between the linear layers of
            the regression head.
        pretrained: Load ImageNet weights for the backbone when true; build
            a randomly initialised backbone (no download) when false.
        backbone_name: Which torchvision ResNet supplies the reused layers.
            Only ``"resnet18"`` and ``"resnet34"`` are accepted.

    Raises:
        ValueError: If ``use_resnet_layers`` is not 1 or 2, or if
            ``backbone_name`` is not supported.
    """

    def __init__(
        self,
        input_channels=3,
        use_resnet_layers=2,
        dropout_rate=0.3,
        pretrained=True,
        backbone_name="resnet18",
    ):
        super().__init__()
        if use_resnet_layers not in (1, 2):
            raise ValueError("use_resnet_layers must be 1 or 2")

        resnet = self._build_backbone(backbone_name, pretrained)

        # Honour ``input_channels``: the pretrained stem only fits its own
        # channel count, so any other value gets a new, untrained stem conv.
        stem_conv = resnet.conv1
        if input_channels != stem_conv.in_channels:
            stem_conv = nn.Conv2d(
                input_channels,
                stem_conv.out_channels,
                kernel_size=stem_conv.kernel_size,
                stride=stem_conv.stride,
                padding=stem_conv.padding,
                bias=stem_conv.bias is not None,
            )
            nn.init.kaiming_normal_(
                stem_conv.weight, mode="fan_out", nonlinearity="relu"
            )

        self.conv1 = stem_conv
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        conv_out_channels = resnet.bn1.num_features

        if use_resnet_layers == 2:
            first_block = resnet.layer1[0]
            self.conv2 = first_block.conv1
            self.bn2 = first_block.bn1
            self.conv2_2 = first_block.conv2
            self.bn2_2 = first_block.bn2
            self.relu2 = first_block.relu
            # Same (parameter-free) pooling module as ``self.maxpool``.
            self.maxpool2 = resnet.maxpool
            conv_out_channels = first_block.bn2.num_features

        self.use_resnet_layers = use_resnet_layers
        self.feature_map_size = 64

        self.conv_head = nn.Sequential(
            nn.Conv2d(conv_out_channels, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # forward() concatenates four 256-element tensors per sample.
        feature_dim = 256 * 4
        self.head = nn.Sequential(
            nn.Linear(feature_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 6),
        )
        self._init_weights()

    @staticmethod
    def _build_backbone(backbone_name, pretrained):
        """Instantiate the torchvision ResNet whose early layers are reused."""
        # Built inside the call so ``models`` is only touched at construction.
        supported = {
            "resnet18": (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1),
            "resnet34": (models.resnet34, models.ResNet34_Weights.IMAGENET1K_V1),
        }
        if backbone_name not in supported:
            raise ValueError(
                f"Unsupported backbone_name {backbone_name!r}; "
                f"expected one of {sorted(supported)}"
            )
        builder, imagenet_weights = supported[backbone_name]
        return builder(weights=imagenet_weights if pretrained else None)

    def _init_weights(self):
        """Kaiming-initialise the linear layers of ``head`` and zero their biases.

        Only ``self.head`` is touched; backbone layers and ``conv_head`` keep
        whatever initialisation they already have.
        """
        for module in self.head.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_in", nonlinearity="relu"
                )
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, img1, img2):
        """Regress 6 values per sample from a pair of image batches.

        Returns:
            Tensor with 6 columns. ``tanh`` bounds every column to [-1, 1];
            columns 2..5 are then multiplied by pi, so they lie in [-pi, pi].
        """
        feat1 = self._extract_features(img1)
        feat2 = self._extract_features(img2)
        combined = torch.cat(
            [feat1, feat2, torch.abs(feat1 - feat2), feat1 * feat2], dim=1
        )
        output = torch.tanh(self.head(combined))
        # NOTE(review): column 5 is labelled "scale" by decode_output() but is
        # multiplied by pi together with the three angle columns - confirm
        # that this is intended.
        return torch.cat([output[:, :2], output[:, 2:6] * torch.pi], dim=1)

    def _extract_features(self, x):
        """Run one image batch through the shared extractor; returns a 2-D tensor."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        if self.use_resnet_layers >= 2:
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.relu(x)
            x = self.conv2_2(x)
            x = self.bn2_2(x)
            x = self.relu2(x)
            x = self.maxpool2(x)
        x = self.conv_head(x)
        x = self.global_pool(x)
        # Collapse the pooled (C, 1, 1) maps into one vector per sample.
        return x.view(x.size(0), -1)

    def decode_output(self, output):
        """Return columns 0..5 of ``output`` as a new tensor.

        Column order is kept: (tx, ty, angle1, angle2, angle3, scale).
        """
        tx = output[:, 0]
        ty = output[:, 1]
        angle1 = output[:, 2]
        angle2 = output[:, 3]
        angle3 = output[:, 4]
        scale = output[:, 5]
        return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)

    def get_components(self, output):
        """Split a network output into a dict of named per-sample columns."""
        decoded = self.decode_output(output)
        return {
            "tx": decoded[:, 0],
            "ty": decoded[:, 1],
            "rx": decoded[:, 2],
            "ry": decoded[:, 3],
            "rz": decoded[:, 4],
            "scale": decoded[:, 5],
        }
class HomographyLoss6(nn.Module):
def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0):
super().__init__()
@@ -138,5 +258,8 @@ class HomographyLoss6(nn.Module):
}
HomographyLoss = HomographyLoss6
def count_parameters(model):
    """Return the total number of elements over all of ``model``'s parameters.

    Every tensor yielded by ``model.parameters()`` contributes its
    ``numel()``; a model with no parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

View File

@@ -12,12 +12,12 @@ from .utils import config
class HomographyTrainer:
def __init__(self, model, train_loader, val_loader, device):
def __init__(self, model, train_loader, val_loader, device, criterion):
self.model = model.to(device)
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
self.criterion = HomographyLoss6()
self.criterion = criterion
self.optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=1e-4)
self.writer = None
self.best_val_loss = float("inf")