Spaces:
Running
on
T4
Running
on
T4
| import torch | |
| from transformers import LogitsWarper | |
class TypicalLogitsWarper(LogitsWarper):
    """Filter logits so that only the most "typical" tokens remain.

    Each token is ranked by the absolute difference between its surprisal
    (negative log-probability) and the entropy of the predicted distribution.
    Walking that ranking from the smallest difference upwards, tokens are kept
    until their cumulative probability reaches ``mass``; every other token has
    its score replaced by ``filter_value``.

    Args:
        mass: Cumulative probability the kept tokens should reach. When no
            prefix of the ranking reaches it (e.g. ``mass >= 1``), every
            token is kept.
        filter_value: Score written over the tokens that are filtered out.
        min_tokens_to_keep: Number of top-ranked tokens that are never
            filtered, regardless of ``mass``.
    """

    def __init__(
        self,
        mass: float = 0.9,
        filter_value: float = -float("Inf"),
        min_tokens_to_keep: int = 1,
    ):
        self.filter_value = filter_value
        self.mass = mass
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with the filtered tokens set to ``filter_value``.

        Filtering is applied independently along the last dimension of
        ``scores``. ``input_ids`` is accepted for interface compatibility and
        is not read.
        """
        # Entropy of the predicted distribution. nansum ignores NaN terms,
        # such as the (-inf * 0) produced by tokens whose score is -inf.
        log_probs = torch.nn.functional.log_softmax(scores, dim=-1)
        probs = torch.exp(log_probs)
        entropy = -(log_probs * probs).nansum(-1, keepdim=True)

        # Rank tokens by |surprisal - entropy|, smallest deviation first, and
        # accumulate their probabilities in that order.
        deviation = torch.abs((-log_probs) - entropy)
        sorted_deviation, sorted_indices = torch.sort(deviation, descending=False)
        sorted_logits = scores.gather(-1, sorted_indices)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Index of the first ranked token at which the cumulative mass reaches
        # `mass`. The count equals the vocabulary size when no prefix reaches
        # `mass`, which would index past the end in the gather below, so clamp
        # it to the last valid position (i.e. keep everything).
        last_ind = (cumulative_probs < self.mass).sum(dim=-1)
        last_ind = last_ind.clamp(max=sorted_deviation.shape[-1] - 1)

        # Drop every token whose deviation exceeds that of the cut-off token.
        threshold = sorted_deviation.gather(-1, last_ind[..., None])
        sorted_indices_to_remove = sorted_deviation > threshold
        if self.min_tokens_to_keep > 1:
            # The top-ranked token can never exceed the threshold, so only
            # larger minimums need to be enforced explicitly.
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = False

        # Map the mask from ranked order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        return scores.masked_fill(indices_to_remove, self.filter_value)