# coding=utf-8
# Copyright 2022 IDEA-CCNL The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Della model. """
import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli


def enforce_repetition_penalty(lprobs, prev_output_tokens, repetition_penalty=1.5):
    """Repetition penalty (from the CTRL paper https://arxiv.org/abs/1909.05858)."""
    for i in range(len(prev_output_tokens)):
        for previous_token in set(prev_output_tokens[i]):
            # if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
            if lprobs[i, previous_token] < 0:
                lprobs[i, previous_token] *= repetition_penalty
            else:
                lprobs[i, previous_token] /= repetition_penalty
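
# Illustrative usage sketch (not part of the original file): enforce_repetition_penalty
# modifies the score tensor in place, so it is called on the next-token scores of the
# current decoding step before sampling. Shapes and token ids below are assumptions.
#
#     scores = torch.randn(2, 50000)              # (batch, vocab) next-token scores
#     prev_tokens = [[5, 17, 5], [42, 42, 7]]     # tokens already generated per batch item
#     enforce_repetition_penalty(scores, prev_tokens, repetition_penalty=1.5)
#     # scores of previously generated tokens are now penalised in place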


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
    Args:
        logits: logits distribution of shape (batch size, vocabulary size)
        top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    # assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the mask back to the original (unsorted) token order, row by row
        for i in range(sorted_indices.size()[0]):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value
        # indices_to_remove = sorted_indices[sorted_indices_to_remove]
        # logits[indices_to_remove] = filter_value
    return logits
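
# Illustrative usage sketch (not part of the original file): a typical sampling step
# that filters next-token logits before multinomial sampling. The shapes and the
# temperature value are assumptions for illustration only.
#
#     next_token_logits = torch.randn(2, 50000) / 0.9           # (batch, vocab), temperature 0.9
#     filtered = top_k_top_p_filtering(next_token_logits, top_k=50, top_p=0.95)
#     probs = F.softmax(filtered, dim=-1)
#     next_tokens = torch.multinomial(probs, num_samples=1)     # (batch, 1) sampled token ids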


def word_drop(x, p, unk_token):
    """Randomly replace tokens in x with unk_token with probability p (word dropout)."""
    x_ = x.detach().clone()
    # Keep each position with probability 1 - p; dropped positions are set to the unknown token
    mask = Bernoulli(1. - p).sample(x.shape)
    x_[mask == 0] = unk_token
    return x_
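
# Illustrative usage sketch (not part of the original file): word dropout is typically
# applied to decoder input ids during VAE training to weaken the decoder and force it
# to rely on the latent code. The token ids below are assumptions for illustration.
#
#     input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])
#     noised_ids = word_drop(input_ids, p=0.3, unk_token=100)   # ~30% of tokens become the UNK id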


def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation
    value.exp().sum(dim, keepdim).log()
    """
    if dim is not None:
        m, _ = torch.max(value, dim=dim, keepdim=True)
        value0 = value - m
        if keepdim is False:
            m = m.squeeze(dim)
        return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim))
    else:
        m = torch.max(value)
        sum_exp = torch.sum(torch.exp(value - m))
        return m + torch.log(sum_exp)
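
# Illustrative note (not part of the original file): the function uses the standard
# log-sum-exp trick, log(sum_i exp(v_i)) = m + log(sum_i exp(v_i - m)) with m = max_i v_i,
# so that exp() never overflows. A quick sanity check against torch.logsumexp:
#
#     v = torch.tensor([[1000.0, 1000.5], [-3.0, 2.0]])
#     torch.allclose(log_sum_exp(v, dim=-1), torch.logsumexp(v, dim=-1))   # True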


def connect(mean, logvar, nsamples=1, sample=True, clip=False, min_clip_val=-1., beta_logvar=1.):
    """Draw latent codes from the posterior defined by (mean, logvar).
    Returns: Tensor
        The latent z with shape [batch, nsamples, nz], squeezed to [batch, nz] when nsamples == 1.
        If sample is False, the posterior mean is returned instead of a sample.
    """
    # (batch, nsamples, nz)
    if sample:
        if clip:
            # NOTE: clip the logvar here to see if we can force z to be more distant
            logvar = torch.clip(logvar, min=min_clip_val)
        z = reparameterize(mean, logvar, nsamples, beta_logvar)
    else:
        batch_size, nz = mean.size()
        z = mean.unsqueeze(1).expand(batch_size, nsamples, nz)
    if nsamples == 1:
        z = z.squeeze(dim=1)
    return z
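
# Illustrative usage sketch (not part of the original file): the posterior parameters
# would normally come from an encoder; the shapes below are assumptions.
#
#     mean = torch.zeros(4, 32)                     # (batch, nz) posterior means
#     logvar = torch.zeros(4, 32)                   # (batch, nz) posterior log-variances
#     z = connect(mean, logvar, nsamples=1)         # (4, 32) one sample per example
#     z_multi = connect(mean, logvar, nsamples=5)   # (4, 5, 32) five samples per example
#     z_mean = connect(mean, logvar, sample=False)  # (4, 32) deterministic, uses the mean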


def reparameterize(mu, logvar, nsamples=1, beta_logvar=1.):
    """Sample from the posterior Gaussian family via the reparameterization trick.
    Args:
        mu: Tensor
            Mean of the Gaussian distribution with shape (batch, nz)
        logvar: Tensor
            Log-variance of the Gaussian distribution with shape (batch, nz)
    Returns: Tensor
        Sampled z with shape (batch, nsamples, nz)
    """
    batch_size, nz = mu.size()
    # std = exp(0.5 * logvar), optionally scaled by beta_logvar
    std = logvar.mul(0.5).exp().mul(beta_logvar)
    mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
    std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)
    eps = torch.zeros_like(std_expd).normal_()
    return mu_expd + torch.mul(eps, std_expd)
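
# Illustrative note (not part of the original file): this is the standard
# reparameterization trick, z = mu + sigma * eps with eps ~ N(0, I) and
# sigma = exp(0.5 * logvar), which keeps the sampling step differentiable
# w.r.t. mu and logvar so gradients can flow back into the encoder.
#
#     z = reparameterize(torch.zeros(2, 8), torch.zeros(2, 8), nsamples=3)   # shape (2, 3, 8)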


def compute_kl_loss(mean1, logvar1, mean2, logvar2):
    """KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))) between two diagonal
    Gaussians, summed over all non-batch dimensions.
    Adapted from the adaVAE implementation https://github.com/ImKeTT/adavae/blob/main/src/adapters/vae.py#L1627
    """
    exponential = logvar1 - logvar2 - torch.pow(mean1 - mean2, 2) / logvar2.exp() - torch.exp(logvar1 - logvar2) + 1
    result = -0.5 * torch.sum(exponential, tuple(range(1, len(exponential.shape))))
    return result
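
# Illustrative note (not part of the original file): with a standard-normal prior
# N(0, I) as the second Gaussian, this reduces to the usual VAE KL term
# 0.5 * sum(mu^2 + exp(logvar) - logvar - 1).
#
#     mean, logvar = torch.zeros(4, 32), torch.zeros(4, 32)
#     prior_mean, prior_logvar = torch.zeros_like(mean), torch.zeros_like(logvar)
#     kl = compute_kl_loss(mean, logvar, prior_mean, prior_logvar)   # shape (4,), all zeros here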