Merge pull request #177 from NanoCode012/fix/landmark-patch

Files changed:
- scripts/finetune.py (+9 -0)
- src/axolotl/monkeypatch/llama_landmark_attn.py (+56 -402)
- src/axolotl/utils/models.py (+7 -14)
- src/axolotl/utils/trainer.py (+7 -4)
    	
scripts/finetune.py (CHANGED)

```diff
@@ -77,6 +77,14 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         importlib.import_module("axolotl.prompters"), prompter
     )
 
+    if cfg.landmark_attention:
+        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
+
+        set_model_mem_id(model, tokenizer)
+        model.set_mem_cache_args(
+            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
+        )
+
     while True:
         print("=" * 80)
         # support for multiline inputs
@@ -90,6 +98,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         else:
             prompt = instruction.strip()
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
         print("=" * 40)
         model.eval()
         with torch.no_grad():
```
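For reference, the new `do_inference` branch is the inference-side counterpart of the trainer change further down: it registers the `<landmark>` token id on the model and configures the memory cache before generation. A minimal sketch of the same wiring as a standalone helper; the `configure_landmark_inference` name is hypothetical, and `model`/`tokenizer` are assumed to come from axolotl's `load_model` with `landmark_attention` enabled (so `set_mem_cache_args` exists on the patched model):

```python
def configure_landmark_inference(model, tokenizer):
    """Hypothetical helper mirroring the new branch in do_inference()."""
    from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id

    # Registers tokenizer.convert_tokens_to_ids("<landmark>") on the model.
    set_model_mem_id(model, tokenizer)
    # mem_freq=50 matches the frequency used when landmark tokens were inserted for training.
    model.set_mem_cache_args(max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None)
```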
    	
src/axolotl/monkeypatch/llama_landmark_attn.py (CHANGED)

```diff
@@ -28,15 +28,24 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import …
-from transformers …
+from torch.nn import CrossEntropyLoss
+from transformers import LlamaTokenizer
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
-    SequenceClassifierOutputWithPast,
 )
-from transformers.modeling_utils import PreTrainedModel
 from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+    LLAMA_INPUTS_DOCSTRING,
+    LLAMA_START_DOCSTRING,
+    LlamaMLP,
+    LlamaPreTrainedModel,
+    LlamaRMSNorm,
+    LlamaRotaryEmbedding,
+    _expand_mask,
+    _make_causal_mask,
+    rotate_half,
+)
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -51,131 +60,6 @@ _CONFIG_FOR_DOC = "LlamaConfig"
 MEM_TOKEN = "<landmark>"  # nosec
 
 
-# Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(
-    input_ids_shape: torch.Size,
-    dtype: torch.dtype,
-    device: torch.device,
-    past_key_values_length: int = 0,
-):
-    """
-    Make causal mask used for bi-directional self-attention.
-    """
-    bsz, tgt_len = input_ids_shape
-    mask = torch.full(
-        (tgt_len, tgt_len),
-        torch.tensor(torch.finfo(dtype).min, device=device),
-        device=device,
-    )
-    mask_cond = torch.arange(mask.size(-1), device=device)
-    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-    mask = mask.to(dtype)
-
-    if past_key_values_length > 0:
-        mask = torch.cat(
-            [
-                torch.zeros(
-                    tgt_len, past_key_values_length, dtype=dtype, device=device
-                ),
-                mask,
-            ],
-            dim=-1,
-        )
-    return mask[None, None, :, :].expand(
-        bsz, 1, tgt_len, tgt_len + past_key_values_length
-    )
-
-
-# Copied from transformers.models.bart.modeling_bart._expand_mask
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-    """
-    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-    """
-    bsz, src_len = mask.size()
-    tgt_len = tgt_len if tgt_len is not None else src_len
-
-    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
-    inverted_mask = 1.0 - expanded_mask
-
-    return inverted_mask.masked_fill(
-        inverted_mask.to(torch.bool), torch.finfo(dtype).min
-    )
-
-
-class LlamaRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        LlamaRMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-
-        return self.weight * hidden_states
-
-
-class LlamaRotaryEmbedding(torch.nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-        # Build here to make `torch.jit.trace` work.
-        self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(
-            self.max_seq_len_cached,
-            device=self.inv_freq.device,
-            dtype=self.inv_freq.dtype,
-        )
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer(
-            "cos_cached", emb.cos()[None, None, :, :], persistent=False
-        )
-        self.register_buffer(
-            "sin_cached", emb.sin()[None, None, :, :], persistent=False
-        )
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
-        if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(
-                self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
-            )
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer(
-                "cos_cached", emb.cos()[None, None, :, :], persistent=False
-            )
-            self.register_buffer(
-                "sin_cached", emb.sin()[None, None, :, :], persistent=False
-            )
-        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-        )
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
@@ -190,24 +74,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     return q_embed, k_embed
 
 
-class LlamaMLP(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        hidden_act: str,
-    ):
-        super().__init__()
-        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
-        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
-        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
-        self.act_fn = ACT2FN[hidden_act]
-
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
 class LandmarkGroupedSoftmaxFunction(torch.autograd.Function):
+    """
+    Landmark grouped softmax function.
+    """
+
     # Note that forward, setup_context, and backward are @staticmethods
     @staticmethod
     def forward(ctx, x, dim, mem_cnt, resp_mem_idx):
@@ -682,16 +553,14 @@ class LlamaAttention(nn.Module):
         # upcast attention to fp32
         if is_mem is None:
             raise ValueError("Don't use this without landmarks")
-
-
-            …
-
-            …
-
-
-
-            last_section_mask=last_section_mask,
-        ).to(query_states.dtype)
+
+        attn_weights = landmark_grouped_softmax(
+            attn_weights,
+            dim=-1,
+            is_mem=is_mem.expand(-1, self.num_heads, -1, -1),
+            last_section_mask=last_section_mask,
+        ).to(query_states.dtype)
+
         if attn_prefix is not None:
             attn_prefix, attn_weights = torch.split(
                 attn_weights,
@@ -722,6 +591,10 @@ class LlamaAttention(nn.Module):
 
 
 class LlamaDecoderLayer(nn.Module):
+    """
+    Llama Decoder layer
+    """
+
     def __init__(self, config: LlamaConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -802,114 +675,6 @@ class LlamaDecoderLayer(nn.Module):
         return outputs
 
 
-LLAMA_START_DOCSTRING = r"""
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-    etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`LlamaConfig`]):
-            Model configuration class with all the parameters of the model. Initializing with a config file does not
-            load the weights associated with the model, only the configuration. Check out the
-            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
-    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
-    LLAMA_START_DOCSTRING,
-)
-class LlamaPreTrainedModel(PreTrainedModel):
-    config_class = LlamaConfig
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["LlamaDecoderLayer"]
-    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
-
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, LlamaModel):
-            module.gradient_checkpointing = value
-
-
-LLAMA_INPUTS_DOCSTRING = r"""
-    Args:
-        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
-            it.
-
-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            [What are input IDs?](../glossary#input-ids)
-        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-
-            [What are attention masks?](../glossary#attention-mask)
-
-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
-            `past_key_values`).
-
-            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
-            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
-            information on the default strategy.
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-            config.n_positions - 1]`.
-
-            [What are position IDs?](../glossary#position-ids)
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
-            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
-
-            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
-            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
-            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
-            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
-            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
-        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-            model's internal embedding lookup matrix.
-        use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-            `past_key_values`).
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
 @add_start_docstrings(
     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
     LLAMA_START_DOCSTRING,
@@ -1178,6 +943,10 @@ class LlamaModel(LlamaPreTrainedModel):
 
 
 class LlamaForCausalLM(LlamaPreTrainedModel):
+    """
+    Llama model with a causal language modeling head.
+    """
+
     def __init__(self, config):
         super().__init__(config)
         self.model = LlamaModel(config)
@@ -1448,148 +1217,33 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         return reordered_past
 
 
-@add_start_docstrings(
-    """
-    The LLaMa Model transformer with a sequence classification head on top (linear layer).
-
-    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
-    (e.g. GPT-2) do.
-
-    Since it does classification on the last token, it requires to know the position of the last token. If a
-    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
-    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
-    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
-    each row of the batch).
-    """,
-    LLAMA_START_DOCSTRING,
-)
-class LlamaForSequenceClassification(LlamaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.num_labels = config.num_labels
-        self.model = LlamaModel(config)
-        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.model.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.model.embed_tokens = value
-
-    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
-            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        """
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-
-        transformer_outputs = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states = transformer_outputs[0]
-        logits = self.score(hidden_states)
-
-        if input_ids is not None:
-            batch_size = input_ids.shape[0]
-        else:
-            batch_size = inputs_embeds.shape[0]
-
-        if self.config.pad_token_id is None and batch_size != 1:
-            raise ValueError(
-                "Cannot handle batch sizes > 1 if no padding token is defined."
-            )
-        if self.config.pad_token_id is None:
-            sequence_lengths = -1
-        else:
-            if input_ids is not None:
-                sequence_lengths = (
-                    torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
-                ).to(logits.device)
-            else:
-                sequence_lengths = -1
-
-        pooled_logits = logits[
-            torch.arange(batch_size, device=logits.device), sequence_lengths
-        ]
-
-        loss = None
-        if labels is not None:
-            labels = labels.to(logits.device)
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (
-                    labels.dtype == torch.long or labels.dtype == torch.int
-                ):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(pooled_logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(
-                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
-                )
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(pooled_logits, labels)
-        if not return_dict:
-            output = (pooled_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return SequenceClassifierOutputWithPast(
-            loss=loss,
-            logits=pooled_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        )
-
-
 def add_mem_tokens(example, mem_freq, mem_id):
-    …
+    ids = example["input_ids"]
     ret = []
     prev_idx = 0
-    for t_idx in range(mem_freq, len(…
-        ret.extend(…
+    for t_idx in range(mem_freq, len(ids), mem_freq):
+        ret.extend(ids[prev_idx:t_idx])
         ret.append(mem_id)
         prev_idx = t_idx
-    ret.extend(…
+    ret.extend(ids[prev_idx:])
     # drop attention_mask
     return {"input_ids": ret}
+
+
+def patch_llama_with_landmark_attn():
+    import transformers
+
+    transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
+    transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
+    transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
+    transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
+    transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
+
+
+def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer):
+    mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
+    model.set_mem_id(mem_id)
+
+
+def get_mem_id(tokenizer: LlamaTokenizer):
+    return tokenizer.convert_tokens_to_ids(MEM_TOKEN)
```
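The rewritten `add_mem_tokens` makes the previously truncated logic explicit: it reads `example["input_ids"]` and inserts the landmark token id after every `mem_freq` tokens. A self-contained check of that behaviour (the token ids and `mem_id=32001` are made-up example values):

```python
def add_mem_tokens(example, mem_freq, mem_id):
    # Same logic as the patched helper above.
    ids = example["input_ids"]
    ret = []
    prev_idx = 0
    for t_idx in range(mem_freq, len(ids), mem_freq):
        ret.extend(ids[prev_idx:t_idx])  # copy the next chunk of ordinary tokens
        ret.append(mem_id)               # then insert one landmark token
        prev_idx = t_idx
    ret.extend(ids[prev_idx:])           # trailing tokens get no landmark
    return {"input_ids": ret}


print(add_mem_tokens({"input_ids": [1, 2, 3, 4, 5, 6, 7]}, mem_freq=3, mem_id=32001))
# {'input_ids': [1, 2, 3, 32001, 4, 5, 6, 32001, 7]}
```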
    	
src/axolotl/utils/models.py (CHANGED)

```diff
@@ -19,15 +19,6 @@ from transformers import (  # noqa: F401
     LlamaConfig,
 )
 
-try:
-    from transformers import (  # pylint: disable=unused-import  # noqa: F401
-        LlamaForCausalLM,
-    )
-except ImportError:
-    logging.warning(
-        "This version of transformers does not support Llama. Consider upgrading."
-    )
-
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
@@ -84,7 +75,6 @@ def load_model(
     Load a model from a base model and a model type.
     """
 
-    global LlamaForCausalLM  # pylint: disable=global-statement
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
     cfg.is_llama_derived_model = "llama" in base_model or (
@@ -112,14 +102,15 @@ def load_model(
         logging.info("patching with sdp attention")
         hijack_llama_sdp_attention()
     elif cfg.is_llama_derived_model and cfg.landmark_attention:
-        from axolotl.monkeypatch.llama_landmark_attn import (…
+        from axolotl.monkeypatch.llama_landmark_attn import (
             MEM_TOKEN,
-            …
+            patch_llama_with_landmark_attn,
         )
 
         logging.info("patching with landmark attention")
+        patch_llama_with_landmark_attn()
 
-        # …
+        # Note: This might overwrite previous additional_special_tokens
         tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
     if cfg.is_llama_derived_model and cfg.xpos_rope:
@@ -204,7 +195,9 @@ def load_model(
                 else True,
             )
             load_in_8bit = False
-        elif cfg.is_llama_derived_model…
+        elif cfg.is_llama_derived_model:
+            from transformers import LlamaForCausalLM
+
             config = LlamaConfig.from_pretrained(base_model_config)
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
```
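Taken together, the landmark branch of `load_model` now applies the monkeypatch, registers the `<landmark>` special token, and defers `from transformers import LlamaForCausalLM` into the llama branch, which is why the old module-level try/except and `global LlamaForCausalLM` could be dropped. A condensed, hypothetical sketch of that ordering (the `load_landmark_llama` wrapper and its tokenizer handling are illustrative, not axolotl's real signature):

```python
from transformers import LlamaConfig, LlamaTokenizer

from axolotl.monkeypatch.llama_landmark_attn import (
    MEM_TOKEN,
    patch_llama_with_landmark_attn,
)


def load_landmark_llama(base_model: str, base_model_config: str):
    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    # 1. Swap the Llama classes in transformers.models.llama.modeling_llama.
    patch_llama_with_landmark_attn()
    # 2. Register the landmark token (may overwrite earlier additional_special_tokens).
    tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})

    # 3. Import LlamaForCausalLM locally, as the new elif branch does.
    from transformers import LlamaForCausalLM

    config = LlamaConfig.from_pretrained(base_model_config)
    model = LlamaForCausalLM.from_pretrained(base_model, config=config)
    return model, tokenizer
```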
    	
src/axolotl/utils/trainer.py (CHANGED)

```diff
@@ -239,16 +239,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.is_llama_derived_model and cfg.landmark_attention:
         from functools import partial
 
-        from axolotl.monkeypatch.llama_landmark_attn import …
+        from axolotl.monkeypatch.llama_landmark_attn import (
+            add_mem_tokens,
+            get_mem_id,
+            set_model_mem_id,
+        )
 
-        …
-        model.set_mem_id(mem_id)
+        set_model_mem_id(model, tokenizer)
 
         logging.info("Adding landmark attention tokens to dataset")
 
         for dataset in [train_dataset, eval_dataset]:
             dataset = dataset.map(
-                partial(add_mem_tokens, mem_freq=50, mem_id=…),
+                partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
                 batched=False,
                 num_proc=32,
             )
```
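On the trainer side, `set_model_mem_id` and `get_mem_id` replace the locally computed `mem_id`, and both datasets are mapped through `add_mem_tokens` with `mem_freq=50`, matching the inference-time `set_mem_cache_args(mem_freq=50, ...)`. A small sketch of the mapping step on a toy dataset; `MEM_ID = 32001` stands in for `get_mem_id(tokenizer)` and the input ids are fabricated:

```python
from functools import partial

from datasets import Dataset

from axolotl.monkeypatch.llama_landmark_attn import add_mem_tokens

MEM_ID = 32001  # stand-in for get_mem_id(tokenizer), i.e. the "<landmark>" token id

dataset = Dataset.from_dict({"input_ids": [list(range(120))]})
dataset = dataset.map(
    partial(add_mem_tokens, mem_freq=50, mem_id=MEM_ID),
    batched=False,
)

# Every 50 original tokens are followed by one landmark token:
assert dataset[0]["input_ids"][50] == MEM_ID
assert dataset[0]["input_ids"][101] == MEM_ID
```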
