# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
from transformers import LogitsProcessor


class TopPProbabilityProcessor(LogitsProcessor):
    # Modified version of TopPLogitsWarper to act on probabilities.
    # Changes:
    # * filter_value changed from -inf to 0
    # * removed softmax
    # * renormalize L1
    def __init__(
        self,
        top_p: float,
        min_tokens_to_keep: int = 1,
    ):
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(
                f"`top_p` has to be a float between 0 and 1, but is {top_p}"
            )
        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
            raise ValueError(
                f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
            )
        self.top_p = top_p
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[batch, seq-len]
        # probs.shape=[batch, vocab]
        sorted_probs, sorted_indices = torch.sort(probs, descending=False)
        cumulative_probs = sorted_probs.cumsum(dim=-1)
        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        # Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        probs = probs.masked_fill(indices_to_remove, 0.0)
        probs = probs / probs.sum(dim=-1, keepdim=True)
        return probs
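

# Illustrative usage sketch (not part of the original file): shows how
# TopPProbabilityProcessor might be applied to an already-normalized probability
# distribution. The shapes, values, and the helper name `_example_top_p_usage`
# are assumptions made for this example only.
def _example_top_p_usage() -> torch.FloatTensor:
    processor = TopPProbabilityProcessor(top_p=0.9, min_tokens_to_keep=1)
    input_ids = torch.zeros(2, 4, dtype=torch.long)  # [batch=2, seq_len=4]
    probs = torch.softmax(torch.randn(2, 8), dim=-1)  # [batch=2, vocab=8]
    filtered = processor(input_ids, probs)
    # The low-probability tail is zeroed out and the rest renormalized to sum to 1.
    assert torch.allclose(filtered.sum(dim=-1), torch.ones(2))
    return filtered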


class DisallowTokensInIndexRangeLogitsProcessor(LogitsProcessor):
    def __init__(
        self, token_ids: list[int], start_index: int, end_index: int | None = None
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_index = start_index
        self.end_index = end_index if end_index is not None else math.inf

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        current_index = input_ids.shape[1]
        if self.start_index <= current_index < self.end_index:
            logits[:, self.token_ids] = -math.inf
        return logits


class DisallowTokensLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
    def __init__(self, token_ids: list[int]):
        super().__init__(token_ids, 0)


class DisallowTokensAtIndexLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index, index + 1)


class DisallowTokensAfterIndexLogitsProcessor(
    DisallowTokensInIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index + 1)


class DisallowTokensAtOrAfterIndexLogitsProcessor(
    DisallowTokensInIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index)
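

# Illustrative usage sketch (not part of the original file): the Disallow* helpers
# mask specific token ids while the running sequence length falls in a given index
# range. The ids and shapes below are arbitrary values chosen for this example.
def _example_disallow_at_index() -> torch.FloatTensor:
    # Forbid tokens 3 and 5 only at generation index 4 (i.e. when seq_len == 4).
    processor = DisallowTokensAtIndexLogitsProcessor(token_ids=[3, 5], index=4)
    input_ids = torch.zeros(1, 4, dtype=torch.long)  # current index is 4
    logits = torch.zeros(1, 8)
    masked = processor(input_ids, logits)
    # The disallowed columns are now -inf; everything else is untouched.
    assert torch.isinf(masked[0, [3, 5]]).all()
    return masked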


class DisallowTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        token_ids: list[int],
        start_indices: list[int],
        end_indices: list[int] | None = None,
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_indices = torch.tensor(start_indices)
        self.end_indices = (
            torch.tensor(end_indices)
            if end_indices is not None
            else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
        )

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape = [batch, seq_len]
        # logits.shape = [batch, vocab]
        current_index = input_ids.shape[1]
        mask = (self.start_indices <= current_index) & (
            current_index < self.end_indices
        )
        # The following will fail if the mask is all False.
        # logits[mask, self.token_ids] = -math.inf
        logits[torch.where(mask)[0].unsqueeze(1), self.token_ids] = -math.inf
        return logits


class DisallowTokensAtBatchIndexLogitsProcessor(
    DisallowTokensInBatchIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], batch_index: list[int]):
        super().__init__(token_ids, batch_index, [i + 1 for i in batch_index])
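

# Illustrative usage sketch (not part of the original file): each batch row gets
# its own [start, end) window, which is why the mask is applied through
# torch.where (a boolean row index would fail when no row is active, as the
# comment above notes). The values below are arbitrary.
def _example_disallow_batch_ranges() -> torch.FloatTensor:
    processor = DisallowTokensInBatchIndexRangeLogitsProcessor(
        token_ids=[2, 7],
        start_indices=[3, 10],  # row 0 active from index 3, row 1 only from index 10
    )
    input_ids = torch.zeros(2, 5, dtype=torch.long)  # current index is 5
    logits = torch.zeros(2, 8)
    masked = processor(input_ids, logits)
    # Only row 0 is inside its window, so only its columns 2 and 7 become -inf.
    assert torch.isinf(masked[0, [2, 7]]).all() and not torch.isinf(masked[1]).any()
    return masked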


class AllowOnlyTokensInIndexRangeLogitsProcessor(LogitsProcessor):
    def __init__(
        self, token_ids: list[int], start_index: int, end_index: int | None = None
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_index = start_index
        self.end_index = end_index if end_index is not None else math.inf

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        current_index = input_ids.shape[1]
        if self.start_index <= current_index < self.end_index:
            replacement = torch.full_like(logits, -math.inf)
            replacement[:, self.token_ids] = logits[:, self.token_ids]
            logits[:] = replacement
        return logits


class AllowOnlyTokensLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
    def __init__(self, token_ids: list[int]):
        super().__init__(token_ids, 0)


class AllowOnlyTokensAtIndexLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index, index + 1)


class AllowOnlyTokensAfterIndexLogitsProcessor(
    AllowOnlyTokensInIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index + 1)


class AllowOnlyTokensAtOrAfterIndexLogitsProcessor(
    AllowOnlyTokensInIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index)


class AllowOnlyTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        token_ids: list[int],
        start_indices: list[int],
        end_indices: list[int] | None = None,
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_indices = torch.tensor(start_indices)
        self.end_indices = (
            torch.tensor(end_indices)
            if end_indices is not None
            else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
        )

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape = [batch, seq_len]
        # logits.shape = [batch, vocab]
        current_index = input_ids.shape[1]
        mask = (self.start_indices <= current_index) & (
            current_index < self.end_indices
        )
        valid_batch_indices = torch.where(mask)[0].unsqueeze(1)
        # Rows inside their [start, end) window are restricted to the allowed
        # tokens; rows outside their window keep their original logits.
        full_mask = torch.full_like(logits, -math.inf)
        full_mask[valid_batch_indices, self.token_ids] = logits[
            valid_batch_indices, self.token_ids
        ]
        logits[valid_batch_indices] = full_mask[valid_batch_indices]
        return logits
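

# Illustrative usage sketch (not part of the original file): the AllowOnly* batch
# variant keeps only the allowed ids for rows inside their window and leaves rows
# outside their window untouched. The values below are arbitrary.
def _example_allow_only_batch_ranges() -> torch.FloatTensor:
    processor = AllowOnlyTokensInBatchIndexRangeLogitsProcessor(
        token_ids=[1, 4],
        start_indices=[2, 9],  # row 0 constrained from index 2, row 1 from index 9
    )
    input_ids = torch.zeros(2, 5, dtype=torch.long)  # current index is 5
    logits = torch.zeros(2, 8)
    constrained = processor(input_ids, logits)
    # Row 0 keeps only tokens 1 and 4; row 1 keeps its full distribution.
    assert torch.isinf(constrained[0, [0, 2, 3, 5, 6, 7]]).all()
    assert not torch.isinf(constrained[1]).any()
    return constrained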


class AllowOnlyTokensAtRelativeOffsetLogitsProcessor(LogitsProcessor):
    def __init__(
        self, trigger_token_id: int, subsequent_token_ids: list[int], offset: int
    ):
        self.trigger_token_id = trigger_token_id
        self.subsequent_token_ids = torch.tensor(subsequent_token_ids)
        self.offset = offset

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[batch, seq_len]
        # logits.shape=[batch, vocab]
        if input_ids.shape[1] < self.offset:
            return logits
        trigger_positions = (
            input_ids[:, -self.offset] == self.trigger_token_id
        ).unsqueeze(-1)
        disallowed_tokens_mask = torch.ones_like(logits, dtype=bool)
        disallowed_tokens_mask[:, self.subsequent_token_ids] = False
        return logits.masked_fill_(
            disallowed_tokens_mask & trigger_positions,
            -math.inf,
        )
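

# Illustrative usage sketch (not part of the original file): if the trigger token
# appeared exactly `offset` positions back, only `subsequent_token_ids` remain
# available at the current step; other rows are left alone. Ids are arbitrary.
def _example_allow_only_at_relative_offset() -> torch.FloatTensor:
    processor = AllowOnlyTokensAtRelativeOffsetLogitsProcessor(
        trigger_token_id=9, subsequent_token_ids=[1, 2], offset=2
    )
    # Row 0 saw the trigger two steps ago; row 1 did not.
    input_ids = torch.tensor([[0, 9, 0], [0, 0, 0]])
    logits = torch.zeros(2, 8)
    constrained = processor(input_ids, logits)
    assert torch.isinf(constrained[0, [0, 3, 4, 5, 6, 7]]).all()
    assert not torch.isinf(constrained[1]).any()
    return constrained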


class AllowOnlyTokensInRelativeWindowLogitsProcessor(LogitsProcessor):
    def __init__(self, trigger_token_id: int, allowed_token_ids: list[int], width: int):
        self.trigger_token_id = trigger_token_id
        self.allowed_token_ids = torch.tensor(allowed_token_ids).unsqueeze(
            0
        )  # shape: [1, num_allowed_tokens]
        self.width = width

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[batch, seq_len]
        # logits.shape=[batch, vocab]
        width = min(self.width, input_ids.shape[1])
        trigger_positions = (
            (input_ids[:, -width:] == self.trigger_token_id).any(dim=1).unsqueeze(-1)
        )
        disallowed_tokens_mask = torch.ones_like(logits, dtype=bool)
        disallowed_tokens_mask[:, self.allowed_token_ids] = False
        return logits.masked_fill_(
            disallowed_tokens_mask & trigger_positions,
            -math.inf,
        )


class CFGLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        guidance_scale: float,
        unconditional_ids: torch.LongTensor,
        model,
    ):
        self.guidance_scale = guidance_scale
        self.unconditional_ids = unconditional_ids
        self.model = model

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        conditioned_logits = logits
        self.unconditional_ids = torch.cat(
            [self.unconditional_ids, input_ids[:, -1:]], dim=1
        )
        unconditioned_outputs = self.model(self.unconditional_ids)
        unconditioned_logits = unconditioned_outputs[:, -1, :]
        return (
            self.guidance_scale * (conditioned_logits - unconditioned_logits)
            + unconditioned_logits
        )


class InBatchCFGLogitsProcessor(LogitsProcessor):
    def __init__(self, guidance_scale: float):
        self.guidance_scale = guidance_scale

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[2*batch, seq-len]
        # logits.shape=[2*batch, vocab]
        conditioned_logits, unconditioned_logits = torch.chunk(logits, chunks=2, dim=0)
        mixed_logits = unconditioned_logits + self.guidance_scale * (
            conditioned_logits - unconditioned_logits
        )
        return mixed_logits.repeat(2, 1)
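

# Illustrative usage sketch (not part of the original file): the in-batch variant
# expects the conditional rows stacked on top of the unconditional rows, so the
# effective batch is doubled. The numbers below are arbitrary.
def _example_in_batch_cfg() -> torch.FloatTensor:
    processor = InBatchCFGLogitsProcessor(guidance_scale=3.0)
    input_ids = torch.zeros(4, 5, dtype=torch.long)  # 2 conditional + 2 unconditional rows
    logits = torch.randn(4, 8)
    mixed = processor(input_ids, logits)
    cond, uncond = logits.chunk(2, dim=0)
    expected = uncond + 3.0 * (cond - uncond)
    # Both halves of the returned batch carry the same mixed logits.
    assert torch.allclose(mixed, expected.repeat(2, 1))
    return mixed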


class InBatchInstructCFGLogitsProcessor(LogitsProcessor):
    # See https://arxiv.org/abs/2211.09800
    def __init__(self, guidance_scale_text: float, guidance_scale_image: float):
        self.guidance_scale_text = guidance_scale_text
        self.guidance_scale_image = guidance_scale_image

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[3*batch, seq-len]
        # logits.shape=[3*batch, vocab]
        (
            full_conditioned_logits,
            image_conditioned_logits,
            unconditioned_logits,
        ) = logits.chunk(3)
        mixed_logits = (
            unconditioned_logits
            + self.guidance_scale_image
            * (image_conditioned_logits - unconditioned_logits)
            + self.guidance_scale_text
            * (full_conditioned_logits - image_conditioned_logits)
        )
        return mixed_logits.repeat(3, 1)
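

# Illustrative usage sketch (not part of the original file): the instruct variant
# (https://arxiv.org/abs/2211.09800) expects three stacked groups per batch:
# fully conditioned, image-only conditioned, and unconditioned rows. Values are
# arbitrary.
def _example_in_batch_instruct_cfg() -> torch.FloatTensor:
    processor = InBatchInstructCFGLogitsProcessor(
        guidance_scale_text=7.5, guidance_scale_image=1.5
    )
    input_ids = torch.zeros(3, 5, dtype=torch.long)  # one row per conditioning group
    logits = torch.randn(3, 8)
    mixed = processor(input_ids, logits)
    full_cond, image_cond, uncond = logits.chunk(3)
    expected = (
        uncond
        + 1.5 * (image_cond - uncond)
        + 7.5 * (full_cond - image_cond)
    )
    assert torch.allclose(mixed, expected.repeat(3, 1))
    return mixed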