clear SiaN

This commit is contained in:
2026-04-04 17:50:10 +03:00
parent 15cad7fa65
commit 702c53caac
12 changed files with 806 additions and 5391 deletions

View File

@@ -38,8 +38,7 @@ class UNetUpBlock(nn.Module):
self.dropout = None
def forward(self, x, skip_input):
print(x.shape)
print(skip_input.shape)
x = self.upconv(x)
# Pad if needed to match skip connection size
if x.shape != skip_input.shape:
@@ -75,14 +74,14 @@ class GeneratorUNet(nn.Module):
nn.ReLU(inplace=True),
)
# Upsampling
self.up1 = UNetUpBlock(1024, 512, dropout=0.5)
self.up2 = UNetUpBlock(1024, 512, dropout=0.5)
self.up3 = UNetUpBlock(1024, 512, dropout=0.5)
self.up4 = UNetUpBlock(1024, 512)
self.up5 = UNetUpBlock(1024, 256)
self.up6 = UNetUpBlock(512, 128)
self.up7 = UNetUpBlock(256, 64)
# Upsampling - input channels from previous layer, output before concat
self.up1 = UNetUpBlock(512, 512, dropout=0.5) # in: 512 (bottleneck) -> out: 512, concat with d7 (512) = 1024
self.up2 = UNetUpBlock(1024, 512, dropout=0.5) # in: 1024 -> out: 512, concat with d6 (512) = 1024
self.up3 = UNetUpBlock(1024, 512, dropout=0.5) # in: 1024 -> out: 512, concat with d5 (512) = 1024
self.up4 = UNetUpBlock(1024, 512) # in: 1024 -> out: 512, concat with d4 (512) = 1024
self.up5 = UNetUpBlock(1024, 256) # in: 1024 -> out: 256, concat with d3 (256) = 512
self.up6 = UNetUpBlock(512, 128) # in: 512 -> out: 128, concat with d2 (128) = 256
self.up7 = UNetUpBlock(256, 64) # in: 256 -> out: 64, concat with d1 (64) = 128
# Final
self.final = nn.Sequential(