clear SiaN
This commit is contained in:
@@ -38,8 +38,7 @@ class UNetUpBlock(nn.Module):
|
||||
self.dropout = None
|
||||
|
||||
def forward(self, x, skip_input):
|
||||
print(x.shape)
|
||||
print(skip_input.shape)
|
||||
|
||||
x = self.upconv(x)
|
||||
# Pad if needed to match skip connection size
|
||||
if x.shape != skip_input.shape:
|
||||
@@ -75,14 +74,14 @@ class GeneratorUNet(nn.Module):
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
# Upsampling
|
||||
self.up1 = UNetUpBlock(1024, 512, dropout=0.5)
|
||||
self.up2 = UNetUpBlock(1024, 512, dropout=0.5)
|
||||
self.up3 = UNetUpBlock(1024, 512, dropout=0.5)
|
||||
self.up4 = UNetUpBlock(1024, 512)
|
||||
self.up5 = UNetUpBlock(1024, 256)
|
||||
self.up6 = UNetUpBlock(512, 128)
|
||||
self.up7 = UNetUpBlock(256, 64)
|
||||
# Upsampling - input channels from previous layer, output before concat
|
||||
self.up1 = UNetUpBlock(512, 512, dropout=0.5) # in: 512 (bottleneck) -> out: 512, concat with d7 (512) = 1024
|
||||
self.up2 = UNetUpBlock(1024, 512, dropout=0.5) # in: 1024 -> out: 512, concat with d6 (512) = 1024
|
||||
self.up3 = UNetUpBlock(1024, 512, dropout=0.5) # in: 1024 -> out: 512, concat with d5 (512) = 1024
|
||||
self.up4 = UNetUpBlock(1024, 512) # in: 1024 -> out: 512, concat with d4 (512) = 1024
|
||||
self.up5 = UNetUpBlock(1024, 256) # in: 1024 -> out: 256, concat with d3 (256) = 512
|
||||
self.up6 = UNetUpBlock(512, 128) # in: 512 -> out: 128, concat with d2 (128) = 256
|
||||
self.up7 = UNetUpBlock(256, 64) # in: 256 -> out: 64, concat with d1 (64) = 128
|
||||
|
||||
# Final
|
||||
self.final = nn.Sequential(
|
||||
|
||||
Reference in New Issue
Block a user