Vizuara committed
Commit 9bfbfd8 · verified · 1 Parent(s): 34c1b45

Create unet_model.py

Files changed (1): unet_model.py (+73 −0)
unet_model.py ADDED
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        # Encoder
        self.enc1 = conv_block(3, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = conv_block(512, 1024)

        # Decoder
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = conv_block(1024, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(128, 64)

        self.conv_last = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        c1 = self.enc1(x)
        p1 = self.pool(c1)

        c2 = self.enc2(p1)
        p2 = self.pool(c2)

        c3 = self.enc3(p2)
        p3 = self.pool(c3)

        c4 = self.enc4(p3)
        p4 = self.pool(c4)

        bottleneck = self.bottleneck(p4)

        u4 = self.upconv4(bottleneck)
        u4 = torch.cat([u4, c4], dim=1)
        d4 = self.dec4(u4)

        u3 = self.upconv3(d4)
        u3 = torch.cat([u3, c3], dim=1)
        d3 = self.dec3(u3)

        u2 = self.upconv2(d3)
        u2 = torch.cat([u2, c2], dim=1)
        d2 = self.dec2(u2)

        u1 = self.upconv1(d2)
        u1 = torch.cat([u1, c1], dim=1)
        d1 = self.dec1(u1)

        return torch.sigmoid(self.conv_last(d1))  # sigmoid kept (matches BCELoss training)
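
A minimal sanity-check sketch (not part of this commit) for exercising the new module. It assumes unet_model.py is importable from the working directory and that the input height and width are divisible by 16, so the four pooling stages divide evenly and the skip connections line up:

import torch
from unet_model import UNet  # assumes the file above is on the import path

model = UNet()
x = torch.randn(1, 3, 256, 256)  # dummy batch: (N, C, H, W), H and W divisible by 16
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 1, 256, 256])
print(y.min().item(), y.max().item())  # values lie in [0, 1] after the final sigmoid

Because forward ends in a sigmoid, the single-channel output is a per-pixel probability map, which pairs with nn.BCELoss during training, as the inline comment on the last line notes.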