try custom cnn
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@ import os
|
||||
import torch
|
||||
|
||||
from .dataloader import create_data_loaders, get_dataset_info
|
||||
from .model import HomographyCNN6, count_parameters
|
||||
from .model import HomographyCNN6, HomographyHybridCNN, HomographyLoss, count_parameters
|
||||
from .train import HomographyTrainer
|
||||
from .analyze import analyze_training
|
||||
from .utils import config
|
||||
@@ -29,18 +29,24 @@ train_loader, val_loader = create_data_loaders(
|
||||
)
|
||||
logger.info(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")

# HomographyCNN6 (configured through config["backbone"]) was used here before;
# this run tries the hybrid ResNet-stem + custom-CNN model instead.
model = HomographyHybridCNN(
    input_channels=3,
    # The configuration key is "dropout_rate" (as read by the previous model
    # setup); the misspelled "droupout_rate" lookup would fail.
    dropout_rate=config["dropout_rate"],
)
logger.info(f"Model created with {count_parameters(model):,} parameters")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# The loss is now injected into the trainer rather than created inside it.
trainer = HomographyTrainer(model, train_loader, val_loader, device, HomographyLoss())
logger.info("Starting training...")
trainer.train(config["epochs"])
logger.info("Training completed")
|
||||
|
||||
@@ -77,6 +77,126 @@ class HomographyCNN6(nn.Module):
|
||||
}
|
||||
|
||||
|
||||
class HomographyHybridCNN(nn.Module):
    """Two-image regressor built on the first layers of an ImageNet ResNet-18.

    Both images go through the same feature extractor: the ResNet-18 stem,
    optionally the layers of the first ``layer1`` block, then a small
    convolutional head and global average pooling.  The two 256-element
    descriptors are fused as ``[x1, x2, |x1 - x2|, x1 * x2]`` and an MLP
    regresses six values from the result.

    Args:
        input_channels: Number of channels of the input images.  With the
            default of 3 the pretrained first convolution is reused.  For any
            other value a freshly initialised convolution with the same
            geometry is used instead, because the pretrained filters only
            accept 3-channel input.
        use_resnet_layers: ``1`` keeps only the ResNet stem; ``2`` also
            applies the layers of the first ``layer1`` block followed by a
            second max-pool.
        dropout_rate: Dropout probability applied after each hidden layer of
            the regression MLP.

    Raises:
        ValueError: If ``use_resnet_layers`` is neither 1 nor 2.
    """

    def __init__(self, input_channels=3, use_resnet_layers=2, dropout_rate=0.3):
        super().__init__()

        # Validate first so a bad argument fails before any weights are loaded.
        if use_resnet_layers not in (1, 2):
            raise ValueError("use_resnet_layers must be 1 or 2")

        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

        # Stem shared by both configurations.
        if input_channels == 3:
            self.conv1 = resnet.conv1
        else:
            # Same geometry as the pretrained layer, but sized for the
            # requested channel count (weights start from default init).
            stem_conv = resnet.conv1
            self.conv1 = nn.Conv2d(
                input_channels,
                stem_conv.out_channels,
                kernel_size=stem_conv.kernel_size,
                stride=stem_conv.stride,
                padding=stem_conv.padding,
                bias=stem_conv.bias is not None,
            )
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        if use_resnet_layers == 2:
            first_block = resnet.layer1[0]
            self.conv2 = first_block.conv1
            self.bn2 = first_block.bn1
            self.conv2_2 = first_block.conv2
            self.bn2_2 = first_block.bn2
            self.relu2 = first_block.relu
            # Same pooling module as the stem, applied a second time.
            self.maxpool2 = resnet.maxpool

        # Both configurations hand 64 channels to the custom head.
        conv_out_channels = 64

        self.use_resnet_layers = use_resnet_layers
        self.feature_map_size = 64

        self.conv_head = nn.Sequential(
            nn.Conv2d(conv_out_channels, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Four 256-element vectors are concatenated in forward().
        feature_dim = 256 * 4

        self.head = nn.Sequential(
            nn.Linear(feature_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 6),
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise the MLP's linear weights and zero their biases."""
        for layer in self.head.modules():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, img1, img2):
        """Regress six values from a pair of images.

        Returns:
            A ``(batch, 6)`` tensor.  Every column is squashed with ``tanh``;
            columns 2-5 are then multiplied by pi, so columns 0-1 lie in
            ``[-1, 1]`` and columns 2-5 in ``[-pi, pi]``.
        """
        feat1 = self._extract_features(img1)
        feat2 = self._extract_features(img2)

        fused = torch.cat([feat1, feat2, torch.abs(feat1 - feat2), feat1 * feat2], dim=1)
        squashed = torch.tanh(self.head(fused))

        # NOTE(review): decode_output labels column 5 as "scale", yet it is
        # multiplied by pi together with the angle columns - confirm intended.
        return torch.cat([squashed[:, :2], squashed[:, 2:6] * torch.pi], dim=1)

    def _extract_features(self, x):
        """Map an image batch to one flat 256-element descriptor per sample."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        if self.use_resnet_layers >= 2:
            # The block's layers are applied in sequence; no skip connection
            # is added here.
            x = self.relu(self.bn2(self.conv2(x)))
            x = self.relu2(self.bn2_2(self.conv2_2(x)))
            x = self.maxpool2(x)

        x = self.global_pool(self.conv_head(x))
        return x.view(x.size(0), -1)

    def decode_output(self, output):
        """Return the six output columns as ``[tx, ty, a1, a2, a3, scale]``.

        Columns 0-1 are read as the translation, 2-4 as the three angles and
        5 as the scale, which is the order they already have in ``output``.
        """
        tx, ty = output[:, 0], output[:, 1]
        angle1, angle2, angle3 = output[:, 2], output[:, 3], output[:, 4]
        scale = output[:, 5]
        return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)

    def get_components(self, output):
        """Split a model output into a dict of named per-sample components."""
        decoded = self.decode_output(output)
        names = ("tx", "ty", "rx", "ry", "rz", "scale")
        return {name: decoded[:, idx] for idx, name in enumerate(names)}
|
||||
|
||||
|
||||
class HomographyLoss6(nn.Module):
|
||||
def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0):
|
||||
super().__init__()
|
||||
@@ -138,5 +258,8 @@ class HomographyLoss6(nn.Module):
|
||||
}
|
||||
|
||||
|
||||
# Alias so the 6-parameter loss can be imported under the name ``HomographyLoss``.
HomographyLoss = HomographyLoss6
|
||||
|
||||
|
||||
def count_parameters(model):
    """Return the total element count over every parameter of ``model``.

    All parameters yielded by ``model.parameters()`` are counted, whether or
    not they require gradients.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
||||
|
||||
@@ -12,12 +12,12 @@ from .utils import config
|
||||
|
||||
|
||||
class HomographyTrainer:
|
||||
def __init__(self, model, train_loader, val_loader, device):
|
||||
def __init__(self, model, train_loader, val_loader, device, criterion):
|
||||
self.model = model.to(device)
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
self.device = device
|
||||
self.criterion = HomographyLoss6()
|
||||
self.criterion = criterion
|
||||
self.optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=1e-4)
|
||||
self.writer = None
|
||||
self.best_val_loss = float("inf")
|
||||
|
||||
Reference in New Issue
Block a user