Upload dcgan_64.py

dcgan_64.py ADDED (+140 -0)
import torch.nn as nn


class dcgan_conv(nn.Module):
    def __init__(self, nin, nout):
        super(dcgan_conv, self).__init__()
        # 4x4 conv, stride 2, pad 1: halves the spatial resolution
        self.main = nn.Sequential(
            nn.Conv2d(nin, nout, 4, 2, 1),
            nn.BatchNorm2d(nout),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, input):
        return self.main(input)


class dcgan_upconv(nn.Module):
    def __init__(self, nin, nout):
        super(dcgan_upconv, self).__init__()
        # 4x4 transposed conv, stride 2, pad 1: doubles the spatial resolution
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nin, nout, 4, 2, 1),
            nn.BatchNorm2d(nout),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, input):
        return self.main(input)
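
# Minimal shape check (illustrative sketch, not part of the uploaded file):
# the two blocks are spatial inverses, so chaining them preserves the input size.
if __name__ == "__main__":
    import torch

    x = torch.randn(2, 3, 64, 64)
    down = dcgan_conv(3, 8)(x)       # -> (2, 8, 32, 32)
    up = dcgan_upconv(8, 3)(down)    # -> (2, 3, 64, 64)
    assert down.shape == (2, 8, 32, 32) and up.shape == x.shape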


class encoder(nn.Module):
    def __init__(self, dim, nc=1):
        super(encoder, self).__init__()
        self.dim = dim
        nf = 64
        # input is (nc) x 64 x 64; each dcgan_conv halves the resolution
        self.c1 = dcgan_conv(nc, nf)          # -> (nf) x 32 x 32
        self.c2 = dcgan_conv(nf, nf * 2)      # -> (nf*2) x 16 x 16
        self.c3 = dcgan_conv(nf * 2, nf * 4)  # -> (nf*4) x 8 x 8
        self.c4 = dcgan_conv(nf * 4, nf * 8)  # -> (nf*8) x 4 x 4
        # final 4x4 conv collapses the 4x4 map to a 1x1, `dim`-channel code in [-1, 1]
        self.c5 = nn.Sequential(nn.Conv2d(nf * 8, dim, 4, 1, 0), nn.BatchNorm2d(dim), nn.Tanh())

    def forward(self, input):
        h1 = self.c1(input)
        h2 = self.c2(h1)
        h3 = self.c3(h2)
        h4 = self.c4(h3)
        h5 = self.c5(h4)
        # return the flat latent code plus the intermediate maps (usable as skips)
        return h5.view(-1, self.dim), [h1, h2, h3, h4]
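
# Illustrative sketch (not in the uploaded file): for 64x64 single-channel
# frames, the encoder yields a flat `dim`-vector plus four feature maps at
# resolutions 32, 16, 8 and 4.
if __name__ == "__main__":
    import torch

    enc = encoder(dim=128, nc=1)
    z, skips = enc(torch.randn(2, 1, 64, 64))
    assert z.shape == (2, 128)
    assert [s.shape[-1] for s in skips] == [32, 16, 8, 4]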


class decoder_convT(nn.Module):
    # NOTE: as uploaded, this is identical to decoder_woSkip below; the skip
    # maps returned by the encoder are not consumed here.
    def __init__(self, dim, nc=1):
        super(decoder_convT, self).__init__()
        self.dim = dim
        nf = 64
        # 1x1 latent -> (nf*8) x 4 x 4
        self.upc1 = nn.Sequential(
            nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
            nn.BatchNorm2d(nf * 8),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.upc2 = dcgan_upconv(nf * 8, nf * 4)  # -> 8 x 8
        self.upc3 = dcgan_upconv(nf * 4, nf * 2)  # -> 16 x 16
        self.upc4 = dcgan_upconv(nf * 2, nf)      # -> 32 x 32
        self.upc5 = nn.Sequential(
            nn.ConvTranspose2d(nf, nc, 4, 2, 1),  # -> (nc) x 64 x 64
            nn.Sigmoid()
        )

    def forward(self, input):
        # input: (batch, time, dim); decode all time steps as one batch
        d1 = self.upc1(input.view(-1, self.dim, 1, 1))
        d2 = self.upc2(d1)
        d3 = self.upc3(d2)
        d4 = self.upc4(d3)
        output = self.upc5(d4)
        # restore the time axis: (batch, time, nc, 64, 64)
        output = output.view(input.shape[0], input.shape[1], output.shape[1], output.shape[2], output.shape[3])
        return output


class decoder_woSkip(nn.Module):
    def __init__(self, dim, nc=1):
        super(decoder_woSkip, self).__init__()
        self.dim = dim
        nf = 64
        # 1x1 latent -> (nf*8) x 4 x 4
        self.upc1 = nn.Sequential(
            nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
            nn.BatchNorm2d(nf * 8),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.upc2 = dcgan_upconv(nf * 8, nf * 4)  # -> 8 x 8
        self.upc3 = dcgan_upconv(nf * 4, nf * 2)  # -> 16 x 16
        self.upc4 = dcgan_upconv(nf * 2, nf)      # -> 32 x 32
        self.upc5 = nn.Sequential(
            nn.ConvTranspose2d(nf, nc, 4, 2, 1),  # -> (nc) x 64 x 64
            nn.Sigmoid()
        )

    def forward(self, input):
        d1 = self.upc1(input.view(-1, self.dim, 1, 1))
        d2 = self.upc2(d1)
        d3 = self.upc3(d2)
        d4 = self.upc4(d3)
        output = self.upc5(d4)
        output = output.view(input.shape[0], input.shape[1], output.shape[1], output.shape[2], output.shape[3])
        return output
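
# Illustrative sketch (not in the uploaded file): both decoders take a
# (batch, time, dim) latent sequence and return (batch, time, nc, 64, 64).
if __name__ == "__main__":
    import torch

    dec = decoder_woSkip(dim=128, nc=1)
    frames = dec(torch.randn(2, 5, 128))
    assert frames.shape == (2, 5, 1, 64, 64)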


"""
# Decoder variant: bilinear resize followed by a convolution as the
# up-sampling block.
"""
import torch.nn.functional as F


class upconv(nn.Module):
    def __init__(self, nc_in, nc_out):
        super().__init__()
        self.conv = nn.Conv2d(nc_in, nc_out, 3, 1, 1)
        self.norm = nn.BatchNorm2d(nc_out)

    def forward(self, input):
        # double the resolution by resizing, then refine with a 3x3 conv
        out = F.interpolate(input, scale_factor=2, mode='bilinear', align_corners=False)
        return F.relu(self.norm(self.conv(out)))


class decoder_conv(nn.Module):
    def __init__(self, dim, nc=1):
        super(decoder_conv, self).__init__()
        self.dim = dim
        nf = 64

        self.main = nn.Sequential(
            nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
            nn.BatchNorm2d(nf * 8),
            nn.ReLU(),
            # state size. (nf*8) x 4 x 4
            upconv(nf * 8, nf * 4),
            # state size. (nf*4) x 8 x 8
            upconv(nf * 4, nf * 2),
            # state size. (nf*2) x 16 x 16
            upconv(nf * 2, nf * 2),
            # state size. (nf*2) x 32 x 32
            upconv(nf * 2, nf),
            # state size. (nf) x 64 x 64
            nn.Conv2d(nf, nc, 1, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, input):
        # input: (batch, time, dim) -> (batch, time, nc, 64, 64)
        output = self.main(input.view(-1, self.dim, 1, 1))
        output = output.view(input.shape[0], input.shape[1], output.shape[1], output.shape[2], output.shape[3])
        return output
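
# Usage sketch (not in the uploaded file): decoder_conv exposes the same
# (batch, time, dim) -> (batch, time, nc, 64, 64) interface as the
# transposed-conv decoders; resize-then-conv up-sampling is a common way to
# avoid the checkerboard artifacts of strided transposed convolutions.
if __name__ == "__main__":
    import torch

    dec = decoder_conv(dim=128, nc=1)
    frames = dec(torch.randn(2, 5, 128))
    assert frames.shape == (2, 5, 1, 64, 64)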