Vizuara commited on
Commit
45cddc9
·
verified ·
1 Parent(s): aadf139

Update unet_model.py

Browse files
Files changed (1) hide show
  1. unet_model.py +4 -1
unet_model.py CHANGED
@@ -13,6 +13,7 @@ class UNet(nn.Module):
13
  nn.ReLU(inplace=True),
14
  )
15
 
 
16
  self.enc1 = conv_block(3, 64)
17
  self.enc2 = conv_block(64, 128)
18
  self.enc3 = conv_block(128, 256)
@@ -20,8 +21,10 @@ class UNet(nn.Module):
20
 
21
  self.pool = nn.MaxPool2d(2)
22
 
 
23
  self.bottleneck = conv_block(512, 1024)
24
 
 
25
  self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
26
  self.dec4 = conv_block(1024, 512)
27
 
@@ -67,4 +70,4 @@ class UNet(nn.Module):
67
  u1 = torch.cat([u1, c1], dim=1)
68
  d1 = self.dec1(u1)
69
 
70
- return torch.sigmoid(self.conv_last(d1))
 
13
  nn.ReLU(inplace=True),
14
  )
15
 
16
+ # Encoder
17
  self.enc1 = conv_block(3, 64)
18
  self.enc2 = conv_block(64, 128)
19
  self.enc3 = conv_block(128, 256)
 
21
 
22
  self.pool = nn.MaxPool2d(2)
23
 
24
+ # Bottleneck
25
  self.bottleneck = conv_block(512, 1024)
26
 
27
+ # Decoder
28
  self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
29
  self.dec4 = conv_block(1024, 512)
30
 
 
70
  u1 = torch.cat([u1, c1], dim=1)
71
  d1 = self.dec1(u1)
72
 
73
+ return torch.sigmoid(self.conv_last(d1)) # sigmoid kept (matches BCELoss training)