""" PyTorch Della model. """ |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch.distributions import Bernoulli |
|
|
|
|
|
def enforce_repetition_penalty(lprobs, prev_output_tokens, repetition_penalty=1.5):
    """Repetition penalty (from the CTRL paper https://arxiv.org/abs/1909.05858)."""
    for i in range(len(prev_output_tokens)):
        for previous_token in set(prev_output_tokens[i]):
            # Penalise tokens that were already generated: with a penalty > 1, negative
            # scores are pushed further down and positive scores are shrunk.
            if lprobs[i, previous_token] < 0:
                lprobs[i, previous_token] *= repetition_penalty
            else:
                lprobs[i, previous_token] /= repetition_penalty
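
# Illustrative usage (not part of the original module): `lprobs` holds per-token scores
# for each sequence in the batch and `prev_output_tokens` the ids generated so far;
# the function modifies `lprobs` in place.
#
#     lprobs = torch.log_softmax(torch.randn(2, 10), dim=-1)
#     prev_output_tokens = [[1, 4, 4], [7]]
#     enforce_repetition_penalty(lprobs, prev_output_tokens, repetition_penalty=1.5)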
|
|
|
|
|
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # safety check: top_k cannot exceed the vocabulary size
    if top_k > 0:
        # Remove all tokens with a logit below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right so the first token above the threshold is also kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map the mask from sorted order back to the original vocabulary order, row by row.
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits
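
# Illustrative usage (not part of the original module): keep the 50 highest-scoring
# tokens and the nucleus covering 95% of the probability mass, then sample.
#
#     logits = torch.randn(2, 100)                     # (batch size, vocabulary size)
#     filtered = top_k_top_p_filtering(logits, top_k=50, top_p=0.95)
#     next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)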
|
|
|
|
|
def word_drop(x, p, unk_token):
    """Randomly replace each token in `x` with `unk_token` with probability `p`."""
    x_ = x.detach().clone()
    # Draw a keep/drop mask: 1 with probability 1 - p, 0 with probability p.
    mask = Bernoulli(1. - p).sample(x.shape)
    x_[mask == 0] = unk_token
    return x_
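
# Illustrative usage (not part of the original module): drop roughly 30% of the token
# ids in a batch, replacing them with a hypothetical unknown-token id of 3.
#
#     token_ids = torch.randint(10, 100, (2, 8))
#     corrupted = word_drop(token_ids, p=0.3, unk_token=3)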
|
|
|
|
|
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation
    value.exp().sum(dim, keepdim).log()
    """
    if dim is not None:
        # Subtract the per-slice maximum before exponentiating to avoid overflow.
        m, _ = torch.max(value, dim=dim, keepdim=True)
        value0 = value - m
        if keepdim is False:
            m = m.squeeze(dim)
        return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim))
    else:
        # Reduce over all elements.
        m = torch.max(value)
        sum_exp = torch.sum(torch.exp(value - m))
        return m + torch.log(sum_exp)
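
# Sanity check (illustrative, not part of the original module): the result should
# match PyTorch's built-in torch.logsumexp.
#
#     x = torch.randn(4, 6)
#     assert torch.allclose(log_sum_exp(x, dim=-1), torch.logsumexp(x, dim=-1))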
|
|
|
|
|
def connect(mean, logvar, nsamples=1, sample=True, clip=False, min_clip_val=-1., beta_logvar=1.):
    """
    Returns: Tensor
        the latent z with shape [batch, nsamples, nz] ([batch, nz] when nsamples == 1)
    """
    if sample:
        if clip:
            # Clamp the log-variance from below to keep the variance from collapsing.
            logvar = torch.clip(logvar, min=min_clip_val)
        z = reparameterize(mean, logvar, nsamples, beta_logvar)
    else:
        # Deterministic connection: use the posterior mean instead of sampling.
        batch_size, nz = mean.size()
        z = mean.unsqueeze(1).expand(batch_size, nsamples, nz)
    if nsamples == 1:
        z = z.squeeze(dim=1)
    return z
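
# Illustrative usage (not part of the original module): draw one latent sample per
# batch element from a posterior parameterised by `mean` and `logvar`.
#
#     mean, logvar = torch.zeros(4, 32), torch.zeros(4, 32)
#     z = connect(mean, logvar, nsamples=1, sample=True)   # shape (4, 32)
#     z_det = connect(mean, logvar, sample=False)          # deterministic: the mean itself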
|
|
|
|
|
def reparameterize(mu, logvar, nsamples=1, beta_logvar=1.):
    """sample from posterior Gaussian family
    Args:
        mu: Tensor
            Mean of gaussian distribution with shape (batch, nz)
        logvar: Tensor
            logvar of gaussian distribution with shape (batch, nz)
    Returns: Tensor
        Sampled z with shape (batch, nsamples, nz)
    """
    batch_size, nz = mu.size()
    # std = exp(0.5 * logvar), optionally scaled by beta_logvar.
    std = logvar.mul(0.5).exp().mul(beta_logvar)

    mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
    std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)

    # Reparameterization trick: z = mu + eps * std with eps ~ N(0, I).
    eps = torch.zeros_like(std_expd).normal_()

    return mu_expd + torch.mul(eps, std_expd)
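
# Illustrative check (not part of the original module): with logvar = 0 the standard
# deviation is 1, so samples are mu plus unit Gaussian noise.
#
#     mu = torch.zeros(2, 16)
#     z = reparameterize(mu, torch.zeros(2, 16), nsamples=5)   # shape (2, 5, 16)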
|
|
|
|
|
def compute_kl_loss(mean1, logvar1, mean2, logvar2):
    '''adapted from the adaVAE implementation https://github.com/ImKeTT/adavae/blob/main/src/adapters/vae.py#L1627'''
    # KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) ), summed over all non-batch dimensions.
    exponential = logvar1 - logvar2 - torch.pow(mean1 - mean2, 2) / logvar2.exp() - torch.exp(logvar1 - logvar2) + 1
    result = -0.5 * torch.sum(exponential, tuple(range(1, len(exponential.shape))))
    return result
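
# Illustrative usage (not part of the original module): KL between a diagonal Gaussian
# posterior and a standard normal prior (mean 0, logvar 0) for each batch element.
#
#     post_mean, post_logvar = torch.randn(4, 32), torch.randn(4, 32)
#     prior_mean, prior_logvar = torch.zeros(4, 32), torch.zeros(4, 32)
#     kl = compute_kl_loss(post_mean, post_logvar, prior_mean, prior_logvar)   # shape (4,)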
|
|