import torch | |
import torch.nn as nn | |
class AutoEncoder(nn.Module): | |
def __init__(self): | |
super(AutoEncoder, self).__init__() | |
self.encoder = nn.Linear(343, 410) | |
self.sparsify = nn.Sigmoid() | |
self.decoder = nn.Linear(410, 343) | |
def forward(self, out): | |
out = out.view(-1, 343) | |
out = self.encoder(out) | |
out = self.sparsify(out) | |
s_ = out | |
out = self.decoder(out) | |
return out, s_ | |