# DDPMSampler / DDIMSampler are imported here and selected at runtime via
# PDiff.config["sample_mode"].
from .diffusion import DDIMSampler, DDPMSampler, GaussianDiffusionTrainer
from .denoiser import OneDimCNN

import torch
from torch import nn
from torch.nn import functional as F


class PDiff(nn.Module):
    # Hyper-parameters are injected by assigning to this class-level dict
    # before instantiation, e.g. PDiff.config = {...}.
    config = {}

    def __init__(self, sequence_length):
        super().__init__()
        self.sequence_length = sequence_length
        # 1-D CNN denoiser shared by the diffusion trainer and sampler.
        self.net = OneDimCNN(
            layer_channels=self.config["layer_channels"],
            model_dim=self.config["model_dim"],
            kernel_size=self.config["kernel_size"],
        )
        self.diffusion_trainer = GaussianDiffusionTrainer(
            model=self.net,
            beta=self.config["beta"],
            T=self.config["T"],
        )
        # "sample_mode" is a sampler class such as DDPMSampler or DDIMSampler.
        self.diffusion_sampler = self.config["sample_mode"](
            model=self.net,
            beta=self.config["beta"],
            T=self.config["T"],
        )

    def forward(self, x=None, c=0., **kwargs):
        # sample=True switches to the generation path; otherwise this
        # computes the diffusion training loss.
        if kwargs.get("sample"):
            del kwargs["sample"]
            return self.sample(x, c, **kwargs)
        x = x.view(-1, x.size(-1))
        loss = self.diffusion_trainer(x, c, **kwargs)
        return loss

    @torch.no_grad()
    def sample(self, x=None, c=0., **kwargs):
        # Start from pure Gaussian noise when no initial x is given.
        if x is None:
            x = torch.randn((1, self.config["model_dim"]), device=self.device)
        x_shape = x.shape
        x = x.view(-1, x.size(-1))
        return self.diffusion_sampler(x, c, **kwargs).view(x_shape)

    @property
    def device(self):
        return next(self.parameters()).device
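
# Usage sketch (illustrative): the concrete config values below are
# assumptions for demonstration; the real ones depend on the OneDimCNN
# denoiser and the diffusion schedule defined elsewhere in this repo.
#
#     PDiff.config = {
#         "layer_channels": [1, 64, 128, 64, 1],
#         "model_dim": 2048,
#         "kernel_size": 7,
#         "beta": (1e-4, 2e-2),
#         "T": 1000,
#         "sample_mode": DDPMSampler,
#     }
#     model = PDiff(sequence_length=2048)
#     loss = model(x)              # training: returns the diffusion loss
#     params = model(sample=True)  # generation: returns a sampled sequence

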
class OneDimVAE(nn.Module):
    def __init__(self, d_model, d_latent, sequence_length, kernel_size=7, divide_slice_length=64):
        super().__init__()
        # Keep a private copy so mutations below never affect the caller's list.
        self.d_model = list(d_model)
        self.d_latent = d_latent

        # Round sequence_length up to a multiple of divide_slice_length;
        # inputs are expected to be padded to this length.
        if sequence_length % divide_slice_length != 0:
            sequence_length = (sequence_length // divide_slice_length + 1) * divide_slice_length
        assert sequence_length % int(2 ** len(d_model)) == 0, \
            f"Please set divide_slice_length to a multiple of {int(2 ** len(d_model))}."
        # Feature-map length after len(d_model) stride-2 convolutions.
        self.last_length = sequence_length // int(2 ** len(d_model))
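
        # Worked example (illustrative numbers): sequence_length=2000 with
        # divide_slice_length=64 is padded to 2048; with len(d_model)=4
        # stride-2 stages, last_length = 2048 // 2**4 = 128.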

        # Encoder: a stack of stride-2 Conv1d blocks, each halving the length.
        modules = []
        in_dim = 1
        for h_dim in d_model:
            modules.append(nn.Sequential(
                nn.Conv1d(in_dim, h_dim, kernel_size, 2, kernel_size // 2),
                nn.BatchNorm1d(h_dim),
                nn.LeakyReLU(),
            ))
            in_dim = h_dim
        self.encoder = nn.Sequential(*modules)
        self.to_latent = nn.Linear(self.last_length * d_model[-1], d_latent)
        self.fc_mu = nn.Linear(d_latent, d_latent)
        self.fc_var = nn.Linear(d_latent, d_latent)

        # Decoder: mirror of the encoder built from stride-2 transposed convs.
        modules = []
        self.to_decode = nn.Linear(d_latent, self.last_length * d_model[-1])
        # Reverse a local copy so the caller's d_model list is not mutated
        # (the original in-place reverse() leaked a side effect to the caller).
        d_model = list(reversed(d_model))
        for i in range(len(d_model) - 1):
            modules.append(nn.Sequential(
                nn.ConvTranspose1d(d_model[i], d_model[i + 1], kernel_size, 2, kernel_size // 2, output_padding=1),
                nn.BatchNorm1d(d_model[i + 1]),
                nn.ELU(),
            ))
        self.decoder = nn.Sequential(*modules)
        self.final_layer = nn.Sequential(
            nn.ConvTranspose1d(d_model[-1], d_model[-1], kernel_size, 2, kernel_size // 2, output_padding=1),
            nn.BatchNorm1d(d_model[-1]),
            nn.ELU(),
            nn.Conv1d(d_model[-1], 1, kernel_size, 1, kernel_size // 2),
        )

    def encode(self, input, **kwargs):
        # input: [batch_size, sequence_length] -> add a channel dim for Conv1d.
        input = input[:, None, :]
        result = self.encoder(input)
        # Flatten to [batch_size, d_model[-1] * last_length] and project.
        result = torch.flatten(result, start_dim=1)
        result = self.to_latent(result)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return mu, log_var

    def decode(self, z, **kwargs):
        # z: [batch_size, d_latent] -> [batch_size, sequence_length]
        result = self.to_decode(z)
        result = result.view(-1, self.d_model[-1], self.last_length)
        result = self.decoder(result)
        result = self.final_layer(result)
        assert result.shape[1] == 1, f"{result.shape}"
        return result[:, 0, :]

    def reparameterize(self, mu, log_var, **kwargs):
        # Standard reparameterization trick: z = mu + std * eps. Without
        # use_var the model degenerates to a deterministic autoencoder.
        if kwargs.get("use_var"):
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            # Optionally override the learned std with a fixed value.
            manual_std = kwargs.get("manual_std")
            if manual_std is not None:
                std = manual_std
            return eps * std + mu
        else:
            return mu

    def encode_decode(self, input, **kwargs):
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var, **kwargs)
        recons = self.decode(z)
        return recons, input, mu, log_var

    def forward(self, x, **kwargs):
        recons, input, mu, log_var = self.encode_decode(input=x, **kwargs)
        recons_loss = F.mse_loss(recons, input)
        if kwargs.get("use_var"):
            # KL(N(mu, sigma^2) || N(0, 1)) averaged over the batch; requires
            # kld_weight to be passed alongside use_var.
            kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
            loss = recons_loss + kwargs["kld_weight"] * kld_loss
        else:
            loss = recons_loss
        return loss
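
# Usage sketch (illustrative): the channel list, latent size, and sequence
# length below are assumptions chosen so the shapes work out, not values
# fixed by this module.
#
#     vae = OneDimVAE(d_model=[64, 128, 256, 256], d_latent=256,
#                     sequence_length=2048)
#     x = torch.randn(8, 2048)                   # already padded to 2048
#     loss = vae(x, use_var=True, kld_weight=1e-4)
#     mu, log_var = vae.encode(x)
#     recon = vae.decode(mu)                     # deterministic reconstruction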