Doven
update code.
f7009b3
raw
history blame
5.33 kB
from .diffusion import DDIMSampler, DDPMSampler, GaussianDiffusionTrainer
from .denoiser import OneDimCNN
from torch.nn import functional as F
from abc import abstractmethod
from torch import nn
import torch
class PDiff(nn.Module):
config = {}
def __init__(self, sequence_length):
super().__init__()
self.sequence_length = sequence_length
self.net = OneDimCNN(
layer_channels=self.config["layer_channels"],
model_dim=self.config["model_dim"],
kernel_size=self.config["kernel_size"],
)
self.diffusion_trainer = GaussianDiffusionTrainer(
model=self.net,
beta=self.config["beta"],
T=self.config["T"]
)
self.diffusion_sampler = self.config["sample_mode"](
model=self.net,
beta=self.config["beta"],
T=self.config["T"]
)
def forward(self, x=None, c=0., **kwargs):
if kwargs.get("sample"):
del kwargs["sample"]
return self.sample(x, c, **kwargs)
x = x.view(-1, x.size(-1))
loss = self.diffusion_trainer(x, c, **kwargs)
return loss
@torch.no_grad()
def sample(self, x=None, c=0., **kwargs):
if x is None:
x = torch.randn((1, self.config["model_dim"]), device=self.device)
x_shape = x.shape
x = x.view(-1, x.size(-1))
return self.diffusion_sampler(x, c, **kwargs).view(x_shape)
@property
def device(self):
return next(self.parameters()).device
class OneDimVAE(nn.Module):
def __init__(self, d_model, d_latent, sequence_length, kernel_size=7, divide_slice_length=64):
super(OneDimVAE, self).__init__()
self.d_model = d_model.copy()
self.d_latent = d_latent
# confirm self.last_length
sequence_length = (sequence_length // divide_slice_length + 1) * divide_slice_length \
if sequence_length % divide_slice_length != 0 else sequence_length
assert sequence_length % int(2 ** len(d_model)) == 0, \
f"Please set divide_slice_length to {int(2 ** len(d_model))}."
self.last_length = sequence_length // int(2 ** len(d_model))
# Build Encoder
modules = []
in_dim = 1
for h_dim in d_model:
modules.append(nn.Sequential(
nn.Conv1d(in_dim, h_dim, kernel_size, 2, kernel_size//2),
nn.BatchNorm1d(h_dim),
nn.LeakyReLU()
))
in_dim = h_dim
self.encoder = nn.Sequential(*modules)
self.to_latent = nn.Linear(self.last_length * d_model[-1], d_latent)
self.fc_mu = nn.Linear(d_latent, d_latent)
self.fc_var = nn.Linear(d_latent, d_latent)
# Build Decoder
modules = []
self.to_decode = nn.Linear(d_latent, self.last_length * d_model[-1])
d_model.reverse()
for i in range(len(d_model) - 1):
modules.append(nn.Sequential(
nn.ConvTranspose1d(d_model[i], d_model[i+1], kernel_size, 2, kernel_size//2, output_padding=1),
nn.BatchNorm1d(d_model[i + 1]),
nn.ELU(),
))
self.decoder = nn.Sequential(*modules)
self.final_layer = nn.Sequential(
nn.ConvTranspose1d(d_model[-1], d_model[-1], kernel_size, 2, kernel_size//2, output_padding=1),
nn.BatchNorm1d(d_model[-1]),
nn.ELU(),
nn.Conv1d(d_model[-1], 1, kernel_size, 1, kernel_size//2),
)
def encode(self, input, **kwargs):
# print(input.shape)
# assert input.shape == [batch_size, num_parameters]
input = input[:, None, :]
result = self.encoder(input)
# print(result.shape)
result = torch.flatten(result, start_dim=1)
result = self.to_latent(result)
mu = self.fc_mu(result)
log_var = self.fc_var(result)
return mu, log_var
def decode(self, z, **kwargs):
# z.shape == [batch_size, d_latent]
result = self.to_decode(z)
result = result.view(-1, self.d_model[-1], self.last_length)
result = self.decoder(result)
result = self.final_layer(result)
assert result.shape[1] == 1, f"{result.shape}"
return result[:, 0, :]
def reparameterize(self, mu, log_var, **kwargs):
if kwargs.get("use_var"):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
if kwargs.get("manual_std") is not None:
std = kwargs.get("manual_std")
return eps * std + mu
else: # not use var
return mu
def encode_decode(self, input, **kwargs):
mu, log_var = self.encode(input)
z = self.reparameterize(mu, log_var, **kwargs)
recons = self.decode(z)
return recons, input, mu, log_var
def forward(self, x, **kwargs):
recons, input, mu, log_var = self.encode_decode(input=x, **kwargs)
recons_loss = F.mse_loss(recons, input)
if kwargs.get("use_var"):
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
loss = recons_loss + kwargs['kld_weight'] * kld_loss
else: # not use var
loss = recons_loss
return loss