diff --git "a/unsloth-main/unsloth-main/unsloth/models/llama.py" "b/unsloth-main/unsloth-main/unsloth/models/llama.py"
new file mode 100644
--- /dev/null
+++ "b/unsloth-main/unsloth-main/unsloth/models/llama.py"
@@ -0,0 +1,2548 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import gc
+import math
+from typing import Optional, Tuple, List, Union
+from ._utils import *
+from ._utils import __version__
+from torch.nn.functional import scaled_dot_product_attention
+from transformers import __version__ as transformers_version
+from transformers.models.llama.modeling_llama import (
+    logger,
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from ..kernels import *
+from ..tokenizer_utils import *
+if HAS_FLASH_ATTENTION:
+    from flash_attn import flash_attn_func
+
+# Final patching code
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaModel,
+    LlamaForCausalLM,
+)
+
+# For PyTorch 2.1.1
+try:
+    from transformers.models.llama.modeling_llama import (
+        LlamaSdpaAttention,
+        LlamaFlashAttention2,
+    )
+except:
+    LlamaSdpaAttention   = LlamaAttention
+    LlamaFlashAttention2 = LlamaAttention
+pass
+
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
+from transformers import set_seed as transformers_set_seed
+from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
+from peft import PeftModelForCausalLM
+from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
+from peft.tuners.lora import Linear4bit as Peft_Linear4bit
+from ..save import patch_saving_functions
+import re, os, inspect, math, sys
+from huggingface_hub.utils._token import get_token
+
+
+def original_apply_qkv(self, X):
+    Q = self.q_proj(X)
+    K = self.k_proj(X)
+    V = self.v_proj(X)
+    return Q, K, V
+pass
+
+
+def original_apply_o(self, X):
+    O = self.o_proj(X)
+    return O
+pass
+
+from math import sqrt as math_sqrt
+KV_CACHE_INCREMENT = 256 # KV Cache update size
+torch_nn_functional_softmax = torch.nn.functional.softmax
+
+# Fix newer HF versions' inference code
+def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,):
+    if "past_key_values" in kwargs:
+        input_ids = input_ids[:,[-1]]
+        kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]]
+    if "cache_position" in kwargs:
+        kwargs["position_ids"] = kwargs["cache_position"]
+    return { "input_ids" : input_ids, **kwargs, }
+pass
+
+
+def fix_prepare_inputs_for_generation(module):
+    # Fix prepare_inputs_for_generation
+    if hasattr(module, "prepare_inputs_for_generation"):
+        module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation
+    pass
+pass
+
+torch_matmul = torch.matmul
+def LlamaAttention_fast_forward_inference(
+    self,
+    hidden_states:  torch.Tensor,
+    past_key_value: Optional[Tuple[torch.Tensor]],
+    position_ids,
+    do_prefill = False,
+    attention_mask = None,
+):
+    """
+        https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
+        Fast inference using KV cache.
+        QK^T can be computed in 4 chunks
+
+        [Q, q] @ [K, k].T where q, k are the new tokens.
+        [QK^T, Qk^T]
+        [qK^T, qk^T]
+
+        Since the attention mask wipes Qk^T, we just get
+        [QK^T,    0]
+        [qK^T, qk^T]
+
+        Since softmax is row-wise, we get
+        softmax([QK^T,    0])
+        softmax([qK^T, qk^T])
+
+        We then multiply by   [V]
+                              [v]
+        softmax([QK^T,    0]) [softmax(QK^T)V] *
+        softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
+
+        But notice * [softmax(QK^T)V] is just the last attention.
+        We just need to compute the last final row.
+
+        This means we can pass in a row of Q, but we need to
+        remember K and V, which are called the KV cache.
+    """
+    Xn = hidden_states
+    bsz, _, hd = hidden_states.size()
+    K1, V1 = past_key_value
+    dtype = Xn.dtype
+
+    n_heads    = self.num_heads
+    n_groups   = self.num_key_value_groups
+    n_kv_heads = self.num_key_value_heads
+    head_dim   = self.head_dim
+    attention_size = n_heads*head_dim
+    # assert(n_kv_heads * n_groups == n_heads)
+    seq_len = K1.shape[-2]
+    kv_seq_len = seq_len + 1
+
+    # Prefill phase
+    # if not hasattr(self, "paged_attention"):
+    if do_prefill:
+        self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
+        self.paged_attention_K = self.paged_attention[:,0]
+        self.paged_attention_V = self.paged_attention[:,1]
+        self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
+        self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
+        self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
+        self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
+        self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+        
+        # Mistral Nemo 12B has n_heads * head_dim != hidden_size, so o_proj needs its own output buffer
+        if attention_size != self.hidden_size:
+            self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
+        else:
+            self.temp_O = self.temp_QA[1][:,:,:self.hidden_size]
+        pass
+        
+        self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
+        self.scalar = 1.0 / math_sqrt(self.head_dim)
+        self.half_head_dim = head_dim // 2
+    elif kv_seq_len >= self.paged_attention.shape[0]:
+        self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
+        self.paged_attention_K = self.paged_attention[:,0]
+        self.paged_attention_V = self.paged_attention[:,1]
+        self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
+    pass
+
+    Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
+    Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
+    Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
+    Qn = Qn.view(bsz, 1, n_heads,    head_dim).transpose(1, 2)
+    Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
+    Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
+
+    # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
+    # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
+    cos, sin = self.rotary_emb.get_cached(kv_seq_len)
+    cos = cos[position_ids].unsqueeze(1)
+    sin = sin[position_ids].unsqueeze(1)
+    h = self.half_head_dim
+
+    RH_Q = self.RH_Q
+    RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
+    RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
+    torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
+    Qn *= cos
+    Qn.addcmul_(RH_Q, sin)
+
+    RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+    RH_K[:,:,:,:h] = Kn[:,:,:,h:]
+    RH_K[:,:,:,h:] = Kn[:,:,:,:h]
+    torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
+    Kn *= cos
+    Kn.addcmul_(RH_K, sin)
+    
+    # New KV cache
+    # Kn = torch.cat([K1, Kn], dim = 2)
+    # Vn = torch.cat([V1, Vn], dim = 2)
+    self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
+    self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
+    Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
+    Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
+
+    # Handle sliding windows
+    sliding_window = getattr(self.config, "sliding_window", None)
+    if sliding_window is not None and kv_seq_len > sliding_window:
+        # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
+        slicing_tokens = 1 - sliding_window
+        Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
+        Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
+    else:
+        Knn, Vnn = Kn, Vn
+    pass
+
+    # Grouped query attention
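+    # Each KV head is repeated n_groups times: expand() creates a broadcast view, reshape() then materialises (bsz, n_heads, cached_len, head_dim) so K / V line up with the query heads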
+    _, _, cached_len, _ = Knn.shape
+    if n_groups != 1:
+        Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+        Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+        Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
+        Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
+    pass
+    # else:
+    #     Knn, Vnn = Knn, Vnn
+    # pass
+
+    # Attention
+    if bsz == 1:
+        Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
+        # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
+        A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
+        # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
+        A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
+        A = torch_matmul(A, Vnn, out = Qn)
+    else:
+        A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
+    pass
+    A = A.transpose(1, 2)
+    A = A.reshape(bsz, 1, attention_size)
+    A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
+    return A, (Kn, Vn)
+pass
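+
+
+# Illustrative sketch (an assumption for exposition, not used by the fast path above):
+# decoding one new token against cached K / V reproduces the final row of full
+# causal attention, which is what the docstring derivation above claims.
+def _kv_cache_last_row_reference(Q, K, V):
+    # Q, K, V: (bsz, n_heads, seq_len, head_dim), float32 for simplicity
+    bsz, n_heads, seq_len, head_dim = Q.shape
+    scale = head_dim ** -0.5
+
+    # Full causal attention, then keep only the last row
+    causal = torch.full((seq_len, seq_len), float("-inf"), device = Q.device).triu(1)
+    full = torch.softmax(Q @ K.transpose(2, 3) * scale + causal, dim = -1) @ V
+    last_row_full = full[:, :, -1:, :]
+
+    # KV cache style: only the newest query row attends to all cached keys / values
+    q_new = Q[:, :, -1:, :]
+    last_row_cached = torch.softmax(q_new @ K.transpose(2, 3) * scale, dim = -1) @ V
+
+    return torch.allclose(last_row_full, last_row_cached, atol = 1e-5)
+pass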
+
+
+torch_nn_functional_silu = torch.nn.functional.silu
+def fast_swiglu_inference(self, X):
+    # gate = self.gate_proj(X)
+    # up   = self.up_proj(X)
+    bsz, _, hd = X.shape
+    # mlp_size = self.config.intermediate_size
+    # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
+
+    gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0])
+    up   = fast_linear_forward(self.  up_proj, X)#, out = temp[1])
+    gate = torch_nn_functional_silu(gate, inplace = True)
+    gate *= up
+
+    # X = self.down_proj(gate)
+    down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
+    return down
+pass
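+
+
+# Reference for what fast_swiglu_inference computes, assuming the standard Llama
+# MLP layout (gate_proj, up_proj, down_proj). A sketch for clarity only; the fast
+# path above reuses output buffers to avoid extra allocations.
+def _swiglu_reference(mlp, X):
+    # down_proj( silu(gate_proj(X)) * up_proj(X) )
+    return mlp.down_proj(torch_nn_functional_silu(mlp.gate_proj(X)) * mlp.up_proj(X))
+pass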
+
+
+def fast_rms_layernorm_inference(self, X):
+    old_dtype = X.dtype
+    XX = X.to(torch.float32)
+    variance = XX.square().mean(-1, keepdim = True)
+    variance += self.variance_epsilon
+    XX *= variance.rsqrt_()
+    X = XX.to(old_dtype) # Must preserve due to residual
+    X *= self.weight
+    return X
+pass
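+
+
+# Reference formula for the RMS layernorm above (a sketch, not used by the
+# inference path): y = x / sqrt(mean(x^2) + eps) * weight, with the reduction
+# done in float32 and the result cast back to the input dtype before scaling.
+def _rms_layernorm_reference(layernorm, X):
+    XX = X.to(torch.float32)
+    XX = XX * torch.rsqrt(XX.square().mean(-1, keepdim = True) + layernorm.variance_epsilon)
+    return XX.to(X.dtype) * layernorm.weight
+pass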
+
+
+def fast_rms_layernorm_inference_gemma(self, X, out_weight = None):
+    XX = X.to(torch.float32)
+    variance = XX.square().mean(-1, keepdim = True)
+    variance += self.variance_epsilon
+    XX *= variance.rsqrt_()
+
+    if out_weight is None:
+        out_weight = self.weight + 1.0
+    else:
+        out_weight[:] = self.weight
+        out_weight += 1.0
+    pass
+
+    XX *= out_weight
+    return XX.to(X.dtype)
+pass
+
+
+# Standard LayerNorm (with mean removal, unlike RMSNorm)
+@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+def fast_layernorm_compiled(layernorm, X):
+    old_dtype = X.dtype
+    X = X.float()
+    mean = X.mean(-1, keepdim = True)
+    Xbar = X - mean
+    X = Xbar * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + \
+        layernorm.variance_epsilon) * \
+        layernorm.weight.float()
+    return X.to(old_dtype)
+pass
+
+
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320
+def LlamaAttention_fast_forward(
+    self,
+    hidden_states:        torch.Tensor,
+    causal_mask:          Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
+    attention_mask:       Optional[torch.Tensor] = None,
+    position_ids:         Optional[torch.LongTensor] = None,
+    past_key_value:       Optional[Tuple[torch.Tensor]] = None,
+    output_attentions:    bool = False,
+    use_cache:            bool = False,
+    padding_mask:         Optional[torch.LongTensor] = None,
+    *args, **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    
+    # Clear inference
+    if hasattr(self, "paged_attention"):
+        del self.paged_attention_K
+        del self.paged_attention_V
+        del self.paged_attention
+        del self.temp_QA
+        del self.temp_KV
+        del self.RH_Q
+        del self.attention
+    pass
+
+    bsz, q_len, _ = hidden_states.size()
+
+    n_heads    = self.num_heads
+    n_groups   = self.num_key_value_groups
+    n_kv_heads = self.num_key_value_heads
+    head_dim   = self.head_dim
+    assert(n_kv_heads * n_groups == n_heads)
+
+    Q, K, V = self.apply_qkv(self, hidden_states)
+    Q = Q.view(bsz, q_len, n_heads,    head_dim).transpose(1, 2)
+    K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+
+    kv_seq_len = K.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+
+    # Extend RoPE dynamically to fit in VRAM
+    rotary_emb = self.rotary_emb
+    rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
+
+    if position_ids is None:
+        # Useful for LongRoPE
+        cos, sin = rotary_emb.get_cached(kv_seq_len)
+        # cos = self.rotary_emb.cos_cached
+        # sin = self.rotary_emb.sin_cached
+        Q, K = fast_rope_embedding(Q, K, cos, sin)
+    else:
+        cos, sin = rotary_emb(V, seq_len = kv_seq_len)
+        Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
+    pass
+
+    if past_key_value is not None:
+        K = torch.cat([past_key_value[0], K], dim = 2)
+        V = torch.cat([past_key_value[1], V], dim = 2)
+    pass
+    past_key_value = (K, V) if use_cache else None
+
+    # Attention module
+    if (not HAS_FLASH_ATTENTION and attention_mask is None):
+        # Xformers memory efficient attention
+        # Also has Flash Attention v2 dispatching
+        Q = Q.transpose(1, 2)
+        K = K.transpose(1, 2)
+        V = V.transpose(1, 2)
+
+        # Grouped query attention
+        if n_groups != 1:
+            K = K  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
+            V = V  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
+            K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
+            V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
+            if hidden_states.requires_grad:
+                K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
+                V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
+            else:
+                Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
+        pass
+        A = xformers_attention(Q, K, V, attn_bias = causal_mask)
+        A = A.view(bsz, q_len, n_heads, head_dim)
+
+    elif HAS_FLASH_ATTENTION and attention_mask is None:
+        Q = Q.transpose(1, 2)
+        K = K.transpose(1, 2)
+        V = V.transpose(1, 2)
+        A = flash_attn_func(Q, K, V, causal = True)
+    else:
+        # Grouped query attention
+        if n_groups != 1:
+            K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
+            V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
+            K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
+            V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
+        pass
+        # Must be contiguous or else results are wrong!
+        # https://github.com/pytorch/pytorch/issues/112577
+        Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
+        # Needs (batch_size, n_heads, seq_len, head_dim)
+        # is_causal and attention_mask must not both be set!
+        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
+        # Go back to (batch_size, seq_len, n_heads, head_dim)
+        A = A.transpose(1, 2).contiguous()
+    pass
+    attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
+    attn_output = self.apply_o(self, attn_output)
+    attn_weights = None
+    return attn_output, attn_weights, past_key_value
+pass
+
+
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
+def LlamaDecoderLayer_fast_forward(
+    self,
+    hidden_states:        torch.Tensor,
+    causal_mask           = None,
+    attention_mask:       Optional[torch.Tensor] = None,
+    position_ids:         Optional[torch.LongTensor] = None,
+    past_key_value:       Optional[Tuple[torch.Tensor]] = None,
+    output_attentions:    Optional[bool] = False,
+    use_cache:            Optional[bool] = False,
+    padding_mask:         Optional[torch.LongTensor] = None,
+    *args, **kwargs,
+) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    """
+    Args:
+        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+            `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+            returned tensors for more detail.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+            (see `past_key_values`).
+        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+    """
+    if use_cache and hasattr(self, "_flag_for_generation"):
+        residual = hidden_states
+        hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            causal_mask=causal_mask,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            padding_mask=padding_mask,
+        )
+        hidden_states += residual
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
+        hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
+        hidden_states += residual
+    else:
+        residual = hidden_states
+        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            causal_mask=causal_mask,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            padding_mask=padding_mask,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+    pass
+
+    outputs = (hidden_states,)
+    if output_attentions: outputs += (self_attn_weights,)
+    if use_cache: outputs += (present_key_value,)
+    return outputs
+pass
+
+
+# https://github.com/unslothai/unsloth/issues/404#issuecomment-2323473452
+__DTYPE_MAP = {
+    "float32": torch.float32,
+    torch.float32: torch.float32,
+    "float16": torch.float16,
+    torch.float16: torch.float16,
+    "bfloat16": torch.bfloat16,
+    torch.bfloat16: torch.bfloat16,
+}
+
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
+def LlamaModel_fast_forward(
+    self,
+    input_ids:            torch.LongTensor,
+    causal_mask:          Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
+    attention_mask:       Optional[torch.Tensor] = None,
+    position_ids:         Optional[torch.LongTensor] = None,
+    past_key_values:      Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds:        Optional[torch.FloatTensor] = None,
+    use_cache:            Optional[bool] = None,
+    output_attentions:    Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict:          Optional[bool] = None,
+    *args, **kwargs,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    assert(output_attentions is False)
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        raise ValueError("Unsloth: You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+    elif input_ids is not None:
+        batch_size, seq_length = input_ids.shape
+    elif inputs_embeds is not None:
+        batch_size, seq_length, _ = inputs_embeds.shape
+    else:
+        raise ValueError("Unsloth: You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+    seq_length_with_past = seq_length
+
+    # Fix out of bounds tokenization
+    if hasattr(self, "max_seq_length"):
+        if seq_length > self.max_seq_length:
+            logger.warning_once(
+                f"Unsloth: Input IDs of length {seq_length} > the model's max sequence length of {self.max_seq_length}.\n"\
+                "We shall truncate it ourselves. It's imperative if you correct this issue first."
+            )
+        if input_ids is not None:
+            input_ids = input_ids[:,:self.max_seq_length]
+        elif inputs_embeds is not None:
+            inputs_embeds = inputs_embeds[:,:self.max_seq_length,:]
+        pass
+    pass
+    
+    past_key_values_length = 0
+
+    if past_key_values is not None:
+        past_key_values_length = past_key_values[0][0].shape[2]
+        seq_length_with_past = seq_length_with_past + past_key_values_length
+    pass
+
+    # We already handle KV cache position_ids ourselves.
+    if False:#(past_key_values_length != 0):
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length,
+            dtype  = torch.int32,
+            device = "cuda:0",
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+    elif position_ids is not None:
+        position_ids = position_ids.view(-1, seq_length).to(torch.int32)#.long()
+    else:
+        position_ids = None
+    pass
+
+    if position_ids is not None:
+        if position_ids.shape[0] != batch_size:
+            position_ids = position_ids.repeat((batch_size, 1))
+    pass
+
+    # Embed positions
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    # inputs_embeds = inputs_embeds.to(self.config.torch_dtype)
+    torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None)
+    if torch_dtype is not None:
+        inputs_embeds = inputs_embeds.to(torch_dtype)
+    else:
+        raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!")
+    pass
+
+    # Model type flags (Gemma normalizes the input embeddings below)
+    IS_GEMMA  = self.config.model_type.startswith("gemma")
+    IS_GEMMA2 = self.config.model_type.startswith("gemma2")
+    IS_COHERE = self.config.model_type.startswith("cohere")
+    train_embed_tokens = self.embed_tokens.weight.requires_grad
+
+    if IS_GEMMA:
+        # Match Gemma exactly by casting to bfloat16 / float16
+        # inputs_embeds *= math_sqrt(self.config.hidden_size)
+        # i.e. 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
+        # &  2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
+        normalizer = torch.tensor(math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype)
+
+        if train_embed_tokens:
+            # Careful we must not do an inplace op!
+            inputs_embeds = inputs_embeds * normalizer
+        else:
+            inputs_requires_grad = inputs_embeds.requires_grad
+            if not inputs_embeds.is_leaf:
+                inputs_embeds = inputs_embeds.detach()
+                inputs_requires_grad = True
+            elif inputs_requires_grad:
+                inputs_embeds.requires_grad_(False)
+            pass
+            inputs_embeds *= normalizer
+            # inputs_embeds *= math_sqrt(self.config.hidden_size)
+            if inputs_requires_grad: inputs_embeds.requires_grad_(True)
+        pass
+    pass
+
+    # Zero out embeddings at padded positions using the attention mask
+    # Specifically needed for DPO
+    if self._has_no_labels and (attention_mask is not None) and (past_key_values is None) and \
+        (not train_embed_tokens):
+        # Careful: for inference the attention_mask has size (1, kv_seq_len),
+        # whilst inputs_embeds has size (1, 1, 4096)
+        inputs_requires_grad = inputs_embeds.requires_grad
+        if not inputs_embeds.is_leaf:
+            inputs_embeds = inputs_embeds.detach()
+            inputs_requires_grad = True
+        elif inputs_requires_grad:
+            inputs_embeds.requires_grad_(False)
+        pass
+        inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
+        if inputs_requires_grad: inputs_embeds.requires_grad_(True)
+    pass
+
+    # Ignore attention_mask
+    if attention_mask is None:
+        padding_mask = None
+    elif self.training:
+        attention_mask = None
+        padding_mask = None
+    else:
+        # if 0 in attention_mask:
+        #     padding_mask = attention_mask
+        # else:
+        padding_mask = None
+
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+            attention_mask,
+            (batch_size, seq_length),
+            inputs_embeds,
+            past_key_values_length,
+            sliding_window = getattr(self.config, "sliding_window", None),
+        )
+    pass
+
+    hidden_states = inputs_embeds
+
+    if past_key_values is None and self.training:
+        use_cache = False
+        # if use_cache:
+        #     logger.warning_once(
+        #         "Unsloth: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`"
+        #     )
+        #     use_cache = False
+    pass
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = () if use_cache else None
+
+    # Gradient checkpointing methods (e.g. sqrt checkpointing)
+    if hasattr(self, "_gradient_checkpointing_boundaries"):
+        boundaries = self._gradient_checkpointing_boundaries
+    else:
+        boundaries = None
+    pass
+
+    # Check checkpointing method
+    gradient_checkpointing = False
+    offloaded_gradient_checkpointing = False
+
+    if (self.gradient_checkpointing and self.training and not use_cache):
+
+        gradient_checkpointing = True
+
+        if output_attentions is False and hasattr(self, "_offloaded_gradient_checkpointing"):
+            offloaded_gradient_checkpointing = True
+    pass
+
+    # Gemma-2 alternates sliding window attention (SWA) and global attention
+    if IS_GEMMA2:
+        if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
+            self.SWA_mask = True
+            self.GA_mask  = False
+        elif attention_mask is not None:
+            self.SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window = self.config.sliding_window,
+            )
+            self.GA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window = None,
+            )
+        elif not hasattr(self, "SWA_mask"):
+            if HAS_FLEX_ATTENTION:
+                # Use Flex Attention instead!
+                self.SWA_mask = create_flex_attention_sliding_window_mask(self.max_seq_length, self.config.sliding_window)
+                self.GA_mask  = create_flex_attention_causal_mask(self.max_seq_length)
+            else:
+                n = self.max_seq_length # self.config.max_position_embeddings
+                # masked_fill is making stuff slower!
+                # self. GA_mask = create_boolean_mask(n = n, sliding_window = 0)
+                # self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window)
+                from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+                self.SWA_mask = AttentionMaskConverter(
+                    is_causal = True,
+                    sliding_window = self.config.sliding_window,
+                )\
+                    .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
+                    .squeeze(0).squeeze(0)
+
+                self.GA_mask = AttentionMaskConverter(
+                    is_causal = True,
+                )\
+                    .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
+                    .squeeze(0).squeeze(0)
+            pass
+        pass
+    pass
+
+    # Go through every layer!
+    for idx, decoder_layer in enumerate(self.layers):
+
+        if output_hidden_states: all_hidden_states += (hidden_states,)
+        past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+        mask = causal_mask
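+        # Gemma-2: even-indexed layers use the sliding window mask, odd-indexed layers the global mask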
+        if IS_GEMMA2: mask = self.SWA_mask if (idx % 2 == 0) else self.GA_mask
+
+        if offloaded_gradient_checkpointing:
+            hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
+                decoder_layer,
+                hidden_states,
+                mask,
+                attention_mask,
+                position_ids,
+                past_key_values,
+                output_attentions,
+                use_cache,
+            )[0]
+
+        elif gradient_checkpointing:
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask)
+                return custom_forward
+            pass
+
+            layer_outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(decoder_layer),
+                hidden_states,
+                mask,
+                attention_mask,
+                position_ids,
+                use_reentrant = True,
+                preserve_rng_state = False,
+            )
+            hidden_states = layer_outputs[0]
+
+        else:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                causal_mask=mask,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                padding_mask=padding_mask,
+            )
+            hidden_states = layer_outputs[0]
+        pass
+
+        if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+        if output_attentions: all_self_attns += (layer_outputs[1],)
+    pass
+
+    # Final layernorm
+    if use_cache:
+        hidden_states = \
+            (fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\
+            (self.norm, hidden_states)
+    elif IS_COHERE:
+        hidden_states = self.norm(hidden_states)
+    else:
+        hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA)
+    pass
+
+    if output_hidden_states: all_hidden_states += (hidden_states,)
+    next_cache = next_decoder_cache if use_cache else None
+
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+pass
+
+
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
+def LlamaModel_fast_forward_inference(
+    self,
+    input_ids,
+    past_key_values,
+    position_ids,
+    attention_mask = None,
+):
+    input_ids = input_ids[:,:self.max_seq_length]
+    hidden_states = self.model.embed_tokens(input_ids)
+    hidden_states = hidden_states.to(self.config.torch_dtype)
+    bsz, q_len, hd = hidden_states.shape
+    seq_len = past_key_values[0][0].shape[-2]
+    if bsz != 1:
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+            attention_mask,
+            (bsz, q_len),
+            hidden_states,
+            seq_len,
+            sliding_window = getattr(self.config, "sliding_window", None),
+        )
+    else:
+        attention_mask = None
+    pass
+
+    next_decoder_cache = []
+    for idx, decoder_layer in enumerate(self.model.layers):
+        residual = hidden_states
+        hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states)
+        hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
+            decoder_layer.self_attn,
+            hidden_states = hidden_states,
+            past_key_value = past_key_values[idx],
+            position_ids = position_ids,
+            attention_mask = attention_mask,
+            do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
+        )
+        hidden_states += residual
+
+        residual = hidden_states
+        hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states)
+        hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
+        hidden_states += residual
+
+        next_decoder_cache.append(present_key_value)
+    pass
+    hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states)
+
+    return BaseModelOutputWithPast(
+        last_hidden_state = hidden_states,
+        past_key_values = next_decoder_cache,
+        hidden_states = [],
+        attentions = [],
+    )
+pass
+
+
+def CausalLM_fast_forward(fast_forward_inference):
+    def _CausalLM_fast_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        num_logits_to_keep: Optional[int] = 0,
+        *args, **kwargs,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        
+        if past_key_values is not None:
+            outputs = fast_forward_inference(
+                self,
+                input_ids,
+                past_key_values,
+                position_ids = position_ids,
+                attention_mask = attention_mask,
+            )
+        else:
+            causal_mask = xformers.attn_bias.LowerTriangularMask()
+
+            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+            output_hidden_states = (
+                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            )
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+            # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+            self.model._has_no_labels = labels is None
+            outputs = self.model(
+                input_ids=input_ids,
+                causal_mask=causal_mask,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        pass
+        hidden_states = outputs[0]
+        bsz, q_len, hd = hidden_states.shape
+        lm_head = self.lm_head.weight
+        if bsz == 1 and q_len == 1:
+            logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))
+            logits = logits.unsqueeze(0).unsqueeze(0)
+        elif num_logits_to_keep != 0:
+            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype))
+        else:
+            logits = self.lm_head(hidden_states.to(lm_head.dtype))
+        pass
+
+        torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None)
+        if torch_dtype is not None:
+            logits = logits.to(torch_dtype)
+        else:
+            raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!")
+        pass
+
+        loss = None
+        logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
+        logit_scaling     = getattr(self.config, "logit_scale", 0)
+        if labels is not None:
+            shift_logits = logits
+            if not hasattr(self, "extra_ignored_labels"):
+                # Fixes https://github.com/unslothai/unsloth/issues/10
+                self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0")
+            pass
+            
+            shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
+            loss = fast_cross_entropy_loss(
+                logits = shift_logits,
+                labels = shift_labels,
+                logit_softcapping = logit_softcapping,
+                logit_scaling     = logit_scaling,
+            )
+        else:
+            if logit_scaling != 0:
+                if logits.requires_grad:
+                    logits = logit_scaling * logits
+                else:
+                    logits *= logit_scaling
+                pass
+            pass
+            if logit_softcapping != 0:
+                if logits.requires_grad:
+                    logits = (1.0 / logit_softcapping) * logits
+                    logits = torch.tanh(logits)
+                    logits = logit_softcapping * logits
+                else:
+                    logits *= (1.0 / logit_softcapping)
+                    torch.tanh(logits, out = logits)
+                    logits *= logit_softcapping
+                pass
+            pass
+        pass
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+    pass
+    return _CausalLM_fast_forward
+pass
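+
+
+# Final logit softcapping, as used above for models like Gemma-2, squashes logits
+# into (-softcap, softcap) via capped = softcap * tanh(logits / softcap).
+# A standalone sketch of the same transform (illustrative only):
+def _logit_softcapping_reference(logits, logit_softcapping):
+    if logit_softcapping == 0: return logits
+    return logit_softcapping * torch.tanh(logits / logit_softcapping)
+pass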
+
+
+@torch._disable_dynamo
+def PeftModelForCausalLM_fast_forward(
+    self,
+    input_ids=None,
+    causal_mask=None,
+    attention_mask=None,
+    inputs_embeds=None,
+    labels=None,
+    output_attentions=None,
+    output_hidden_states=None,
+    return_dict=None,
+    task_ids=None,
+    num_logits_to_keep=0,
+    **kwargs,
+):
+    return self.base_model(
+        input_ids=input_ids,
+        causal_mask=causal_mask,
+        attention_mask=attention_mask,
+        inputs_embeds=inputs_embeds,
+        labels=labels,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        num_logits_to_keep=num_logits_to_keep,
+        **kwargs,
+    )
+pass
+
+
+# Solves https://github.com/unslothai/unsloth/issues/168
+# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
+# Inference can now be CUDA Graphed, but we shall retain the old rotary embeddings.
+# https://github.com/huggingface/transformers/pull/27931
+# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
+class LlamaRotaryEmbedding(torch.nn.Module):
+    # Fixes https://github.com/huggingface/transformers/pull/28837
+    # https://github.com/microsoft/DeepSpeed/issues/4932
+    # The precision of RoPE buffers is not correct, so we cast to int64.
+    def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
+        config = None, # [TODO] Hack to pass in config - need to remove later
+    ):
+        super().__init__()
+        if config is not None:
+            # [TODO] Hack to pass in config - need to remove later
+            base = config.rope_theta
+            partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+            dim = int((config.hidden_size // config.num_attention_heads))
+            device = "cuda"
+            max_position_embeddings = config.max_position_embeddings
+        pass
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        # Dynamic RoPE: start the cache at a max of 4 * 8192 tokens, then grow it iteratively as needed
+        self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
+    pass
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
+        # in FP32. They are applied (multiplied) in FP32 as well.
+        self.current_rope_size = seq_len
+        inv_freq = 1.0 / (
+            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
+        )
+        t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
+
+        freqs = torch.outer(t, inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+    pass
+
+    def forward(self, x, position_ids=None, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.current_rope_size:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype = x.dtype),
+            self.sin_cached[:seq_len].to(dtype = x.dtype),
+        )
+    pass
+
+    def get_cached(self, seq_len = None):
+        return self.cos_cached, self.sin_cached
+    pass
+
+    def extend_rope_embedding(self, x, seq_len):
+        if seq_len <= self.current_rope_size: return
+        # Iteratively grow by increments of 8192
+        self.current_rope_size = math.ceil(seq_len / 8192) * 8192
+        self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
+    pass
+pass
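+
+
+# Example usage of the rotary embedding cache (a sketch; the head_dim of 128 and
+# the tensor shapes below are assumptions for illustration only):
+#   rope = LlamaRotaryEmbedding(dim = 128, max_position_embeddings = 4096, device = "cuda:0")
+#   q = torch.randn(1, 32, 512, 128, device = "cuda:0")
+#   cos, sin = rope(q, seq_len = 512)               # each of shape (512, 128)
+#   rope.extend_rope_embedding(q, seq_len = 10000)  # grows the cache in 8192-token steps -> 16384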
+
+
+class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
+    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+    # Fixes https://github.com/huggingface/transformers/pull/28837
+    # https://github.com/microsoft/DeepSpeed/issues/4932
+    # The precision of RoPE buffers is not correct, so we cast to int64.
+    def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
+        config = None, # [TODO] Hack to pass in config - need to remove later
+    ):
+        self.scaling_factor = scaling_factor
+        super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
+    pass
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.current_rope_size = seq_len
+        inv_freq = 1.0 / (
+            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
+        )
+        t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
+        t = t / self.scaling_factor
+
+        freqs = torch.outer(t, inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+    pass
+pass
+
+
+# See https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py#L736
+# For Llama 3.1
+class LlamaExtendedRotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
+        config = None, # [TODO] Hack to pass in config - need to remove later
+    ):
+        super().__init__()
+        if config is not None:
+            # [TODO] Hack to pass in config - need to remove later
+            base = config.rope_theta
+            partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+            dim = int((config.hidden_size // config.num_attention_heads))
+            device = "cuda"
+            max_position_embeddings = config.max_position_embeddings
+        pass
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        # Dynamic RoPE: start the cache at a max of 4 * 8192 tokens, then grow it iteratively as needed
+        self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
+
+        # Normal Llama-3 RoPE
+        inv_freq = 1.0 / (
+            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
+        )
+        inv_freq = self.apply_scaling(inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent = False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
+    pass
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
+        # in FP32. They are applied (multiplied) in FP32 as well.
+        self.current_rope_size = seq_len
+        
+        t = torch.arange(self.current_rope_size, device=self.inv_freq.device, dtype=torch.int64).float()
+
+        freqs = torch.outer(t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+    pass
+
+    # From https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L41
+    def apply_scaling(self, freqs: torch.Tensor):
+        # Values obtained from grid search
+        scale_factor = 8
+        low_freq_factor = 1
+        high_freq_factor = 4
+        old_context_len = 8192  # original llama3 length
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+        new_freqs = []
+        for freq in freqs:
+            wavelen = 2 * math.pi / freq
+            if wavelen < high_freq_wavelen:
+                new_freqs.append(freq)
+            elif wavelen > low_freq_wavelen:
+                new_freqs.append(freq / scale_factor)
+            else:
+                assert low_freq_wavelen != high_freq_wavelen
+                smooth = (old_context_len / wavelen - low_freq_factor) / (
+                    high_freq_factor - low_freq_factor
+                )
+                new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
+        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+    pass
+
+    def forward(self, x, position_ids=None, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.current_rope_size:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype = x.dtype),
+            self.sin_cached[:seq_len].to(dtype = x.dtype),
+        )
+    pass
+
+    def get_cached(self, seq_len = None):
+        return self.cos_cached, self.sin_cached
+    pass
+
+    def extend_rope_embedding(self, x, seq_len):
+        if seq_len <= self.current_rope_size: return
+        # Iteratively grow by increments of 8192
+        self.current_rope_size = math.ceil(seq_len / 8192) * 8192
+        self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
+    pass
+pass
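+
+
+# The Llama 3.1 scaling in apply_scaling above splits RoPE frequencies into three
+# bands by wavelength (wavelen = 2 * pi / freq), using the hard-coded constants:
+#   * wavelen < 8192 / 4 (high frequency):  keep freq unchanged
+#   * wavelen > 8192 / 1 (low frequency):   use freq / 8
+#   * in between: interpolate between freq / 8 and freq with
+#     smooth = (8192 / wavelen - 1) / (4 - 1)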
+
+
+class LongRopeRotaryEmbedding(torch.nn.Module):
+    # For Phi 3.5 128K https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/modeling_phi3.py
+    def __init__(self,
+        dim = None,
+        max_position_embeddings = 131072,
+        original_max_position_embeddings = 4096,
+        base = 10000,
+        short_factor = None,
+        long_factor  = None,
+        device = None,
+        config = None, # [TODO] Hack to pass in config - need to remove later
+    ):
+        super().__init__()
+        assert(short_factor is not None)
+        assert(long_factor  is not None)
+        assert(type(original_max_position_embeddings) is int)
+
+        if config is not None:
+            # [TODO] Hack to pass in config - need to remove later
+            base = config.rope_theta
+            partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+            dim = int((config.hidden_size // config.num_attention_heads))
+            device = "cuda"
+            max_position_embeddings = config.max_position_embeddings
+        pass
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.base = base
+        # Dynamic RoPE: start the cache at the original max position embeddings, then grow it iteratively as needed
+        self.current_rope_size = min(original_max_position_embeddings, self.max_position_embeddings)
+
+        # LongRoPE is like standard RoPE, except short sequences use one cos / sin cache
+        # and long sequences use another, each built from its own rescaled inv_freq
+        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim
+        short_factor = torch.tensor(short_factor, device = "cpu", dtype = torch.float32)
+        long_factor  = torch.tensor(long_factor,  device = "cpu", dtype = torch.float32)
+        short_inv_freq = 1.0 / (short_factor * self.base**inv_freq_shape)
+        long_inv_freq  = 1.0 / (long_factor  * self.base**inv_freq_shape)
+
+        # Phi-3 Scale factor
+        scale = self.max_position_embeddings / self.original_max_position_embeddings
+        if scale <= 1.0:
+            scaling_factor = 1.0
+        else:
+            scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+        pass
+        self.scaling_factor = scaling_factor
+
+        # Short and long inv_freq
+        self.register_buffer("short_inv_freq", short_inv_freq, persistent = False)
+        self.register_buffer("long_inv_freq",  long_inv_freq,  persistent = False)
+        # Build here to make `torch.jit.trace` work.
+        # self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
+
+        # Short sequences
+        dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16
+        t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float()
+        freqs = torch.outer(t, self.short_inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
+        sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
+        self.register_buffer("short_cos_cached", cos_cached, persistent=False)
+        self.register_buffer("short_sin_cached", sin_cached, persistent=False)
+    pass
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
+        # in FP32. They are applied (multiplied) in FP32 as well.
+        self.current_rope_size = seq_len
+        
+        t = torch.arange(self.current_rope_size, device=self.long_inv_freq.device, dtype=torch.int64).float()
+        # Long sequences
+        freqs = torch.outer(t, self.long_inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
+        sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
+        self.register_buffer("long_cos_cached", cos_cached, persistent=False)
+        self.register_buffer("long_sin_cached", sin_cached, persistent=False)
+    pass
+
+    def forward(self, x, position_ids=None, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.current_rope_size:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        if seq_len < self.original_max_position_embeddings:
+            return (
+                self.short_cos_cached[:seq_len].to(dtype = x.dtype),
+                self.short_sin_cached[:seq_len].to(dtype = x.dtype),
+            )
+        else:
+            return (
+                self.long_cos_cached[:seq_len].to(dtype = x.dtype),
+                self.long_sin_cached[:seq_len].to(dtype = x.dtype),
+            )
+        pass
+    pass
+
+    def get_cached(self, seq_len = None):
+        if seq_len < self.original_max_position_embeddings:
+            return self.short_cos_cached, self.short_sin_cached
+        return self.long_cos_cached, self.long_sin_cached
+    pass
+
+    def extend_rope_embedding(self, x, seq_len):
+        if seq_len <= self.current_rope_size: return
+        # Iteratively grow by increments of 8192
+        self.current_rope_size = math.ceil(seq_len / 8192) * 8192
+        self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
+    pass
+pass
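+# Hedged usage sketch (illustrative only): LongRopeRotaryEmbedding keeps two cos/sin caches -
+# "short" tables built from short_inv_freq for sequences below original_max_position_embeddings,
+# and "long" tables built lazily from long_inv_freq (both scaled by scaling_factor) once
+# sequences grow past that point. Roughly, assuming the constructor defined earlier:
+#
+#   # rope = LongRopeRotaryEmbedding(...)              # args come from config.rope_scaling
+#   # rope.extend_rope_embedding(x, seq_len = 10_000)  # grows the long cache in 8192 steps
+#   # cos, sin = rope(x, seq_len = 2048)               # served from the short cache
+#   # cos, sin = rope(x, seq_len = 10_000)             # served from the long cache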
+
+
+def _wrap_fast_inference(generate, device_type, dtype, model):
+    # Wraps inference with bfloat16 / float16
+    @torch.inference_mode
+    def _fast_generate(*args, **kwargs):
+
+        # Set a flag for generation!
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            internal_model._flag_for_generation = True
+            internal_model = internal_model.model
+        pass
+        internal_model._flag_for_generation = True
+
+        # Must patch accelerate for Xformers
+        if accelerate_new_send_to_device is not None:
+            import accelerate.utils.operations
+            accelerate.utils.operations.send_to_device = accelerate_new_send_to_device
+        pass
+
+        # For newer HF
+        kwargs["cache_implementation"] = "dynamic"
+        # For num_logits_to_keep
+        kwargs["num_logits_to_keep"] = 1
+
+        # Remove token_type_ids
+        kwargs.pop("token_type_ids", None)
+
+        # Check pad_token
+        model_eos_token_id = getattr(model.config, "eos_token_id", None)
+        if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"):
+            model_eos_token_id = model_eos_token_id[0]
+
+        kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)
+
+        # Set pad token
+        # old_pad_token_id = getattr(model.config, "pad_token_id", None)
+        # old_eos_token_id = getattr(model.config, "eos_token_id", None)
+        # model.config.pad_token_id = old_eos_token_id
+
+        # Autocasted
+        with torch.autocast(device_type = device_type, dtype = dtype):
+            output = generate(*args, **kwargs)
+        pass
+
+        # Revert
+        # model.config.pad_token_id = old_pad_token_id
+
+        # Unset a flag for generation!
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation
+            internal_model = internal_model.model
+        pass
+        if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation
+
+        # Return accelerate back
+        if accelerate_new_send_to_device is not None:
+            accelerate.utils.operations.send_to_device = accelerate_old_send_to_device
+        pass
+
+        return output
+    pass
+    return _fast_generate
+pass
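+# Hedged usage sketch: `for_inference` below applies this wrapper roughly as
+#
+#   model._unwrapped_old_generate = model.generate
+#   model.generate = _wrap_fast_inference(model.generate, "cuda", torch.float16, model)
+#   output = model.generate(**inputs, max_new_tokens = 64)
+#
+# so generation runs under torch.inference_mode + torch.autocast, every submodule is tagged
+# with `_flag_for_generation` for the fast forward paths, and accelerate's send_to_device
+# patch is reverted afterwards.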
+
+
+class FastLlamaModel:
+
+    @staticmethod
+    def pre_patch():
+        init_name, function = patch_llama_rope_scaling(
+            model_name           = "llama",
+            rope_module          = LlamaRotaryEmbedding,
+            scaled_rope_module   = LlamaLinearScalingRotaryEmbedding,
+            extended_rope_module = LlamaExtendedRotaryEmbedding,
+            attention_module     = LlamaAttention,
+            longrope_module      = LongRopeRotaryEmbedding,
+        )
+        if init_name is not None:
+            exec(function, globals())
+            LlamaAttention.__init__  = eval(init_name)
+        pass
+        LlamaAttention      .forward = LlamaAttention_fast_forward
+        LlamaSdpaAttention  .forward = LlamaAttention_fast_forward
+        LlamaFlashAttention2.forward = LlamaAttention_fast_forward
+        LlamaDecoderLayer   .forward = LlamaDecoderLayer_fast_forward
+        LlamaModel          .forward = LlamaModel_fast_forward
+        LlamaForCausalLM    .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
+        PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
+        fix_prepare_inputs_for_generation(LlamaForCausalLM)
+
+        # Solves https://github.com/unslothai/unsloth/issues/168
+        # Static KV Cache was introduced in 4.38.0, causing training to be much slower.
+        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
+        # https://github.com/huggingface/transformers/pull/27931
+        # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
+        import transformers.models.llama.modeling_llama
+        transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding
+        transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = LlamaLinearScalingRotaryEmbedding
+        return
+    pass
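+    # Note: pre_patch swaps the HF forward implementations for the fast Unsloth ones and pins
+    # the pre-4.38 rotary embedding classes so training avoids the slower static KV cache.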
+
+
+    @staticmethod
+    def from_pretrained(
+        model_name        = "unsloth/llama-3-8b-bnb-4bit",
+        max_seq_length    = None,
+        dtype             = None,
+        load_in_4bit      = True,
+        token             = None,
+        device_map        = "sequential",
+        rope_scaling      = None,
+        fix_tokenizer     = True,
+        model_patcher     = None,
+        tokenizer_name    = None,
+        trust_remote_code = False,
+        **kwargs,
+    ):
+        if trust_remote_code:
+            print(
+                "Unsloth: WARNING `trust_remote_code` is True.\n"\
+                "Are you certain you want to do remote code execution?"
+            )
+        pass
+        if token is None: token = get_token()
+        if model_patcher is None: model_patcher = FastLlamaModel
+        SUPPORTS_BFLOAT16 = is_bfloat16_supported()
+        gpu_stats = torch.cuda.get_device_properties(0)
+        max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
+
+        statistics = \
+           f"==((====))==  Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers = {transformers_version}.\n"\
+           f"   \\\   /|    GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\
+           f"O^O/ \_/ \\    Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
+           f"\        /    Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
+           f' "-____-"     Free Apache license: http://github.com/unslothai/unsloth'
+        print(statistics)
+
+        # Warn about fast transfers
+        old_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0")
+        if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") == "1":
+            print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!")
+        pass
+        # Return old flag
+        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
+
+        model_patcher.pre_patch()
+        get_statistics() # For debugging - we use a download counter to see if environments are not breaking 
+
+        if dtype is None:
+            dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
+        elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
+            logger.warning_once("Device does not support bfloat16. Will change to float16.")
+            dtype = torch.float16
+
+        assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
+
+        # RoPE Scaling
+        model_config = AutoConfig.from_pretrained(model_name, token = token)
+        model_max_seq_length = model_config.max_position_embeddings
+
+        # Check if RoPE Scaling is even allowed
+        model_function = MODEL_FOR_CAUSAL_LM_MAPPING[model_config.__class__]
+        has_rope_scaling = False
+        try:
+            with open(inspect.getfile(model_function), "r") as file:
+                has_rope_scaling = "self.config.rope_scaling" in file.read()
+        except:
+            # If the model's source cannot be inspected, assume RoPE scaling is supported
+            has_rope_scaling = True
+
+        # If max_seq_length is not specified, use the maximum from the config
+        if max_seq_length is None:
+            max_seq_length = model_max_seq_length
+        pass
+
+        if (rope_scaling is None) and (max_seq_length > model_max_seq_length):
+
+            rope_scaling = max_seq_length / model_max_seq_length
+
+            logger.warning_once(
+                f"Unsloth: {model_name} can only handle sequence lengths of at most "\
+                f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "\
+                f"{round(rope_scaling, 3)}, it can be magically be extended to "\
+                f"{max_seq_length}!"
+            )
+
+            # Error out if the model does not support RoPE scaling
+            if not has_rope_scaling:
+                raise RuntimeError(
+                    f"However, {model_name} doesn't support RoPE Scaling!\n"\
+                    "Please file a feature request at https://github.com/unslothai/unsloth."
+                )
+            pass
+
+            rope_scaling = {"type": "linear", "factor": rope_scaling,}
+
+            # Add to kwargs
+            kwargs["rope_scaling"] = rope_scaling
+        pass
+        # We currently only support NVIDIA GPUs - AMD / Intel is a work in progress!
+        pre_check = check_nvidia()
+
+        bnb_config = None
+        if load_in_4bit:
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit              = True,
+                bnb_4bit_use_double_quant = True,
+                bnb_4bit_quant_type       = "nf4",
+                bnb_4bit_compute_dtype    = dtype,
+            )
+        pass
+
+        # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12
+        # RoPE Scaling's max_position_embeddings must be updated
+        max_position_embeddings = max(max_seq_length, model_max_seq_length)
+        kwargs.pop("attn_implementation", None); # No need since we auto call it
+
+        # Cannot be None, since HF now checks for the config
+        if load_in_4bit: kwargs["quantization_config"] = bnb_config
+        
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            device_map              = device_map,
+            torch_dtype             = dtype,
+            # quantization_config     = bnb_config,
+            token                   = token,
+            max_position_embeddings = max_position_embeddings,
+            trust_remote_code       = trust_remote_code,
+            attn_implementation     = "eager",
+            **kwargs,
+        )
+        # Return old flag
+        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
+        # We currently only support NVIDIA GPUs - AMD / Intel is a work in progress!
+        post_check = check_nvidia()
+
+        # Counteract saved tokenizers
+        tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
+        tokenizer = load_correct_tokenizer(
+            tokenizer_name    = tokenizer_name,
+            model_max_length  = max_position_embeddings,
+            padding_side      = "right",
+            token             = token,
+            trust_remote_code = trust_remote_code,
+            fix_tokenizer     = fix_tokenizer,
+        )
+
+        model, tokenizer = patch_tokenizer(model, tokenizer)
+        model = model_patcher.post_patch(model)
+
+        # Patch up QKV / O and MLP
+        for idx, layer in enumerate(model.model.layers):
+            layer.self_attn.apply_qkv = original_apply_qkv
+            layer.self_attn.apply_o   = original_apply_o
+        pass
+
+        # Patch Trainer
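+        # The patching below grabs the source of Trainer._inner_training_loop with
+        # inspect.getsource, rewrites it with string / regex replacements (debug banner,
+        # multi-GPU / TPU / DeepSpeed guards), then exec()s the edited source and swaps it
+        # back in as `_fast_inner_training_loop`.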
+        from transformers.trainer import Trainer
+        try:
+            if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
+                inner_training_loop = inspect.getsource(Trainer._inner_training_loop)
+                Trainer._original_training_loop = inner_training_loop
+            else:
+                inner_training_loop = Trainer._original_training_loop
+        except:
+            raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
+        pass
+
+        if ((post_check - pre_check) >= 1).sum() > 1:
+            raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
+
+        import transformers.trainer
+        items_in_trainer = dir(transformers.trainer)
+        good_items = []
+        for item in items_in_trainer:
+            # TODO: Support Deepspeed
+            if item.startswith(("deepspeed", "xm", "met", "smp")): continue
+            if item in inner_training_loop: good_items.append(item)
+        pass
+        exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals())
+
+        start = re.search(r'logger\.info\([\"\'].+?Running training', inner_training_loop).span(0)[0]
+        end = inner_training_loop.find("\n\n", start)
+        original_debug = inner_training_loop[start:end]
+        spaces = re.search(r'\n([\s\t]{1,})', original_debug).group(0)[1:]
+        front_spaces = re.match(r'([\s\t]{1,})', inner_training_loop).group(0)
+
+        debug_info = """debug_info = \\
+        f"==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = {args.world_size}\\n"\\
+        f"   \\\\\\   /|    Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,}\\n"\\
+        f"O^O/ \\_/ \\    Batch size per device = {self._train_batch_size:,} | Gradient Accumulation steps = {args.gradient_accumulation_steps}\\n"\\
+        f"\\        /    Total batch size = {total_train_batch_size:,} | Total steps = {max_steps:,}\\n"\\
+        f' "-____-"     Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}'
+        logger.warning(debug_info)
+        import subprocess, re, gc, numpy as np
+        a = np.array([0,])
+        try:
+            a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)
+            a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a)
+            a = np.array([int(x.decode('utf-8'))/1024 for x in a])
+        except:
+            if not torch.cuda.is_available():
+                raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')
+        if ((a - PRE_CHECK) >= 1).sum() > 1:
+            raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
+        for _ in range(3):
+            gc.collect()
+            torch.cuda.empty_cache()"""
+
+        debug_info = debug_info.split('\n')
+        debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
+        inner_training_loop = inner_training_loop.replace(original_debug, debug_info)
+
+        debug_info = """n_total_devices = total_train_batch_size // \\
+            args.gradient_accumulation_steps // self._train_batch_size
+        if n_total_devices > 1:
+            logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!')
+        debug_info ="""
+        debug_info = debug_info.split('\n')
+        debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
+        inner_training_loop = inner_training_loop.replace("debug_info =", debug_info, 1)
+
+        front_spaces = re.match(r"[\t\s]{1,}", inner_training_loop).group(0)
+        inner_training_loop = re.sub(r"^" + front_spaces, "", inner_training_loop, flags = re.MULTILINE)
+        inner_training_loop = inner_training_loop.replace(
+            "train_dataloader = tpu_spmd_dataloader(train_dataloader)",
+            "raise RuntimeError('Unsloth: TPUs are not yet supported!')"
+        )
+        inner_training_loop = inner_training_loop.replace(
+            "self.accelerator.free_memory()",
+            "self.accelerator.free_memory()\n" + \
+            front_spaces + "if self.is_deepspeed_enabled:"\
+            "raise RuntimeError('Unsloth: Deepspeed is not yet supported!')\n", 1,
+        )
+
+        check_batches = """train_dataloader = self.get_train_dataloader()
+        ga  = args.gradient_accumulation_steps
+        bsz = self._train_batch_size
+        total_batches = bsz * ga * args.world_size
+        n_total_devices = total_batches // ga // bsz
+        if n_total_devices > 1:
+            logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!')
+            divisor = n_total_devices / 1
+            bsz = self._train_batch_size = max(int(bsz / divisor), 1)
+            if total_batches // ga // bsz > 1:
+                divisor = n_total_devices / 1
+                ga = args.gradient_accumulation_steps = max(int(ga / divisor), 1)"""
+        check_batches = check_batches.split('\n')
+        check_batches = "\n".join([check_batches[0]] + [front_spaces + x[8:] for x in check_batches[1:]])
+        inner_training_loop = inner_training_loop.replace(
+            "train_dataloader = self.get_train_dataloader()",
+            check_batches, 1,
+        )
+        inner_training_loop = inner_training_loop.replace(
+            "_inner_training_loop",
+            "_fast_inner_training_loop", 1,
+        )
+        exec(inner_training_loop, globals())
+
+        Trainer._inner_training_loop = _fast_inner_training_loop
+        inner_training_loop = inner_training_loop.replace(
+            "is_torch_tpu_available()",
+            "False",
+        )
+        if "n_total_devices >" not in inner_training_loop:
+            raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
+        pass
+        inner_training_loop = inner_training_loop.replace(
+            "is_sagemaker_mp_enabled()",
+            "False",
+        )
+        exec(inner_training_loop, globals())
+        Trainer._inner_training_loop = _fast_inner_training_loop
+
+        # Save max_seq_length
+        model.max_seq_length = max_position_embeddings
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            internal_model.max_seq_length = max_position_embeddings
+            internal_model = internal_model.model
+        pass
+        internal_model.max_seq_length = max_position_embeddings
+
+        # We check the tokenizer first for errors
+        if fix_tokenizer:
+            tokenizer = check_tokenizer(
+                model            = model,
+                tokenizer        = tokenizer,
+                model_name       = model_name,
+                model_max_length = max_position_embeddings,
+                padding_side     = "right",
+                token            = token,
+            )
+        pass
+        patch_saving_functions(tokenizer)
+
+        # Fix up config for transformers uploading PEFT
+        # Not necessary anymore since we require transformers>=4.37!
+        if False:
+            name = model.config._name_or_path
+            if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
+                name = name[:len(name) - len("-bnb-4bit")]
+                model.config.update({"_name_or_path" : name})
+            pass
+        pass
+
+        # Log Unsloth version for future fastpaths for inference
+        model.config.update({"unsloth_version" : __version__})
+
+        # Add save modules
+        patch_saving_functions(model)
+        Trainer._inner_training_loop = _fast_inner_training_loop
+
+        # Save tokenizer for inference purposes
+        tokenizer.padding_side = "left" # Force inference
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            internal_model._saved_temp_tokenizer = tokenizer
+            internal_model = internal_model.model
+        pass
+        internal_model._saved_temp_tokenizer = tokenizer
+
+        # Also fix torch_dtype
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            if hasattr(internal_model, "config"):
+                if   internal_model.config.torch_dtype ==  "float32":
+                    internal_model.config.torch_dtype = torch.float32
+                elif internal_model.config.torch_dtype == "bfloat16":
+                    internal_model.config.torch_dtype = torch.bfloat16
+                elif internal_model.config.torch_dtype ==  "float16":
+                    internal_model.config.torch_dtype = torch.float16
+                pass
+            pass
+            internal_model = internal_model.model
+        pass
+        if hasattr(internal_model, "config"):
+            if   internal_model.config.torch_dtype ==  "float32":
+                internal_model.config.torch_dtype = torch.float32
+            elif internal_model.config.torch_dtype == "bfloat16":
+                internal_model.config.torch_dtype = torch.bfloat16
+            elif internal_model.config.torch_dtype ==  "float16":
+                internal_model.config.torch_dtype = torch.float16
+            pass
+        pass
+        
+        return model, tokenizer
+    pass
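+    # Hedged usage sketch (illustrative only, not an additional API):
+    #
+    #   model, tokenizer = FastLlamaModel.from_pretrained(
+    #       model_name     = "unsloth/llama-3-8b-bnb-4bit",
+    #       max_seq_length = 4096,
+    #       load_in_4bit   = True,
+    #   )
+    #
+    # dtype defaults to bfloat16 where supported (else float16), and linear RoPE scaling is
+    # added automatically if max_seq_length exceeds the model's max_position_embeddings.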
+
+
+    @staticmethod
+    def post_patch(model):
+        # Patch model
+        layers = model.model.layers
+
+        # Torch.compile fails on embedding matrix??
+        # Workaround randomly fixes it for torch versions < 2.
+        model.set_input_embeddings(torch.nn.Embedding.from_pretrained(model.get_input_embeddings().weight))
+        model.config.update({"unsloth_version" : __version__})
+
+        # We also do this for the lm_head
+        lm_head = torch.nn.Linear(1, 1, bias = None)
+        del lm_head.weight
+        lm_head.weight = model.get_output_embeddings().weight
+        lm_head.in_features  = lm_head.weight.shape[1]
+        lm_head.out_features = lm_head.weight.shape[0]
+        model.lm_head = lm_head
+        
+        # Also patch all dtypes - BnB seems to not allocate the correct type?
+        # BnB default dtype seems to be float16!
+        correct_dtype = lm_head.weight.dtype
+
+        for name, module in model.named_modules():
+            if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
+                weight = module.weight
+                quant_state = weight.quant_state
+
+                if type(quant_state) is list:
+                    # BnB seems to have float16 as default!
+                    module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
+                else:
+                    # https://github.com/TimDettmers/bitsandbytes/pull/763/files
+                    quant_state.dtype = correct_dtype
+                pass
+            pass
+            # Downcast RoPE embedding to correct data type
+            if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")):
+
+                if hasattr(module, "cos_cached") and \
+                    (module.cos_cached.dtype != correct_dtype):
+
+                    module.cos_cached = module.cos_cached.to(correct_dtype)
+                    module.sin_cached = module.sin_cached.to(correct_dtype)
+
+                elif hasattr(module, "short_cos_cached") and \
+                    (module.short_cos_cached.dtype != correct_dtype):
+                    
+                    module.short_cos_cached = module.short_cos_cached.to(correct_dtype)
+                    module.short_sin_cached = module.short_sin_cached.to(correct_dtype)
+                pass
+            pass
+        pass
+
+        # Clear deleted GPU items
+        for _ in range(3):
+            gc.collect()
+            torch.cuda.empty_cache()
+        return model
+    pass
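+    # Note: post_patch re-binds lm_head to share the output embedding weights and forces
+    # BnB 4-bit quant_state dtypes (plus any cached RoPE tables) to match the lm_head dtype,
+    # since bitsandbytes defaults to float16.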
+
+
+    @staticmethod
+    def get_peft_model(
+        model,
+        r                   = 16,
+        target_modules      = ["q_proj", "k_proj", "v_proj", "o_proj",
+                               "gate_proj", "up_proj", "down_proj"],
+        lora_alpha          = 16,
+        lora_dropout        = 0,
+        bias                = "none",
+        layers_to_transform = None,
+        layers_pattern      = None,
+        use_gradient_checkpointing = True,
+        random_state        = 3407,
+        max_seq_length      = 2048, # not used anymore
+        use_rslora          = False,
+        modules_to_save     = None,
+        init_lora_weights   = True,
+        loftq_config        = {},
+        temporary_location  = "_unsloth_temporary_saved_buffers",
+        **kwargs,
+    ):
+        transformers_set_seed(random_state)
+
+        if isinstance(model, PeftModelForCausalLM):
+            # Check if exactly the same and then pass through!
+            assert(hasattr(model, "peft_config"))
+
+            peft_config = model.peft_config["default"].to_dict()
+            check_parameters = [
+                "r", "lora_alpha", "lora_dropout",
+                "bias", "layers_to_transform", "layers_pattern",
+                "use_rslora", "init_lora_weights",
+            ]
+            check_all = True
+            for param in check_parameters:
+                check_all = check_all and (peft_config[param] == eval(param))
+            pass
+
+            # Check save_modules
+            old_target_modules = list(peft_config["target_modules"])
+            modules_to_save = peft_config["modules_to_save"]
+            if modules_to_save is None: modules_to_save = {}
+            modules_to_save = list(modules_to_save)
+            old_target_modules += modules_to_save
+
+            # Combine all
+            new_target_modules = list(target_modules) + \
+                list(modules_to_save if modules_to_save is not None else [])
+
+            # Now check!
+            new_target_modules = set(new_target_modules)
+            check_all = check_all and (
+                len(set(old_target_modules) ^ new_target_modules) == 0
+            )
+
+            check_all = check_all and (
+                (loftq_config == {} or loftq_config is None) and \
+                (peft_config["loftq_config"] == {} or peft_config["loftq_config"] is None)
+            )
+
+            if check_all:
+                # Simply pass through!
+                logger.warning(
+                    "Unsloth: Already have LoRA adapters! We shall skip this step."
+                )
+
+                # Offload!
+                # [TODO] First offload lm_head and embed_tokens to CPU (should be disk!!)
+                if "embed_tokens" in new_target_modules:
+                    print("Unsloth: Casting embed_tokens to float32")
+
+                    model.model.model.embed_tokens.modules_to_save.default\
+                        .to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
+                    model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
+
+                    # [TODO] Move old embed_tokens to CPU - should be disk!
+                    model.model.model.embed_tokens.original_module\
+                        .to(device = "cpu", non_blocking = True)
+                    model.model.model.embed_tokens.original_module.requires_grad_(False)
+                pass
+
+                if "lm_head" in new_target_modules:
+                    print("Unsloth: Casting lm_head to float32")
+
+                    model.model.lm_head.modules_to_save.default\
+                        .to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
+                    model.model.lm_head.modules_to_save.default.requires_grad_(True)
+
+                    # [TODO] Move old lm_head to CPU - should be disk!
+                    model.model.lm_head.original_module\
+                        .to(device = "cpu", non_blocking = True)
+                    model.model.lm_head.original_module.requires_grad_(False)
+                pass
+
+                return model
+            else:
+                raise TypeError(
+                    "Unsloth: Your model already has LoRA adapters. Your new parameters are different."
+                )
+            pass
+        pass
+
+        if loftq_config is None: loftq_config = {}
+
+        signature = str(inspect.signature(LoraConfig))
+        SUPPORTS_LOFTQ  = "loftq_config" in signature
+        SUPPORTS_RSLORA = "use_rslora"   in signature
+        
+        assert(max_seq_length <= model.max_seq_length)
+
+        if lora_dropout != 0:
+            logger.warning_once(
+                f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"\
+                f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
+            )
+        pass
+
+        if bias != "none":
+            logger.warning_once(
+                f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"\
+                f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
+            )
+        pass
+
+        if not (type(init_lora_weights) is bool or \
+            init_lora_weights == "gaussian" or init_lora_weights == "loftq"):
+            raise ValueError(
+                'Unsloth: `init_lora_weights` must be either [True, False, "gaussian", "loftq"].'
+            )
+        pass
+
+        if init_lora_weights == "loftq":
+
+            if not SUPPORTS_LOFTQ:
+                import peft
+                raise RuntimeError(
+                    f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"\
+                    "Please install PEFT 0.7.2 or higher.\n"\
+                    "You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
+                )
+            pass
+
+            if loftq_config == {}:
+                from peft import LoftQConfig
+                logger.warning_once(
+                    f"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"\
+                    f"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
+                )
+                loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
+            pass
+            
+            if hasattr(model.config, "quantization_config"):
+                raise ValueError(
+                    "Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"\
+                    "Reload your model without any quantization by setting `load_in_4bit = False`."
+                )
+            pass
+        pass
+
+        assert(type(use_rslora) is bool)
+        if use_rslora:
+            if not SUPPORTS_RSLORA:
+                # We manually check for PEFT
+                import peft
+                raise RuntimeError(
+                    f"Unsloth: Your PEFT version of {peft.__version__} does not support `use_rslora`.\n"\
+                    "Please install PEFT 0.7.2 or higher.\n"\
+                    "You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
+                )
+            pass
+        pass
+
+        accepted_modules = frozenset(("q_proj", "k_proj", "v_proj", "o_proj",
+                                      "gate_proj", "up_proj", "down_proj",),)
+        model.config.update({"unsloth_version" : __version__})
+
+        if type(modules_to_save) is tuple:
+            modules_to_save = list(modules_to_save)
+        pass
+
+        train_lm_head = False
+        train_embed_tokens = False
+        final_modules = []
+        for module in target_modules:
+            if module == "lm_head":
+                # logger.warning_once(
+                #     "Unsloth: `lm_head` should be placed in `modules_to_save` and not `target_modules`. "\
+                #     "Luckily, we shall do it for you!"
+                # )
+                train_lm_head = True
+                if modules_to_save is None: modules_to_save = ["lm_head"]
+                else: modules_to_save.append("lm_head")
+
+            elif module == "embed_tokens":
+                # logger.warning_once(
+                #     "Unsloth: `embed_tokens` should be placed in `modules_to_save` and not `target_modules`. "\
+                #     "Luckily, we shall do it for you!"
+                # )
+                train_embed_tokens = True
+                if modules_to_save is None: modules_to_save = ["embed_tokens"]
+                else: modules_to_save.append("embed_tokens")
+
+            else:
+                final_modules.append(module)
+                if module not in accepted_modules:
+                    print(
+                        "Unsloth: You added custom modules, but Unsloth hasn't optimized for this.\n"\
+                        "Beware - your finetuning might be noticeably slower!"
+                    )
+                pass
+            pass
+        pass
+
+        # Check if we added new tokens!
+        if hasattr(model, "_need_to_train_embeddings"):
+            if not train_lm_head or not train_embed_tokens:
+                print(
+                    "Unsloth: You added new tokens but did not specify if you wanted to "\
+                    "train the lm_head and embed_tokens.\nWe must turn it on for you."
+                )
+                train_lm_head = True
+                train_embed_tokens = True
+
+                if modules_to_save is None: modules_to_save = ["embed_tokens"]
+                else: modules_to_save.append("embed_tokens")
+
+                if modules_to_save is None: modules_to_save = ["lm_head"]
+                else: modules_to_save.append("lm_head")
+            pass
+        pass
+
+        # Check for Llama-3
+        # if hasattr(model._saved_temp_tokenizer, "_using_llama3_template"):
+        #     if not train_embed_tokens and not train_lm_head:
+        #         raise RuntimeError("")
+
+        # First fix untrained tokens
+        # Wrong - can cause reserved tokens to pop out!!
+        # if train_embed_tokens or train_lm_head:
+        #     fix_untrained_tokens(model, eps = 1e-16)
+        # pass
+
+        # Check modules_to_save
+        if modules_to_save is not None:
+            for module in modules_to_save:
+                if module == "lm_head":
+                    train_lm_head = True
+                elif module == "embed_tokens":
+                    train_embed_tokens = True
+                else:
+                    raise TypeError(
+                        f"Unsloth: Module = {module} is not allowed. Only 'lm_head' and 'embed_tokens' is allowed."
+                    )
+            pass
+        pass
+        if isinstance(modules_to_save, (tuple, list)):
+            modules_to_save = list(set(modules_to_save))
+        pass
+
+        # Get LoRA
+        arguments = dict(
+            r                   = r,
+            lora_alpha          = lora_alpha,
+            target_modules      = final_modules,
+            lora_dropout        = lora_dropout,
+            bias                = bias,
+            task_type           = TaskType.CAUSAL_LM,
+            layers_to_transform = layers_to_transform,
+            init_lora_weights   = init_lora_weights,
+            loftq_config        = loftq_config,
+            use_rslora          = use_rslora,
+            modules_to_save     = modules_to_save,
+            **kwargs,
+        )
+        if not SUPPORTS_LOFTQ:  del arguments["loftq_config"]
+        if not SUPPORTS_RSLORA: del arguments["use_rslora"]
+
+        _saved_temp_tokenizer = model._saved_temp_tokenizer
+
+        lora_config = LoraConfig(**arguments)
+
+        # First offload lm_head and embed_tokens to disk
+        input_embeddings_device  = model. get_input_embeddings().weight.device
+        output_embeddings_device = model.get_output_embeddings().weight.device
+
+        if use_gradient_checkpointing == "unsloth":
+            if train_embed_tokens:
+                print("Unsloth: Offloading input_embeddings to disk to save VRAM")
+                offload_input_embeddings(model, temporary_location)
+            pass
+
+            # Remove old items to save VRAM
+            for _ in range(3):
+                gc.collect()
+                torch.cuda.empty_cache()
+            pass
+
+            if train_lm_head:
+                print("Unsloth: Offloading output_embeddings to disk to save VRAM")
+                offload_output_embeddings(model, temporary_location)
+            pass
+
+            # Remove old items to save VRAM
+            for _ in range(3):
+                gc.collect()
+                torch.cuda.empty_cache()
+            pass
+        pass
+
+        model = _get_peft_model(model, lora_config)
+
+        model._saved_temp_tokenizer = _saved_temp_tokenizer
+
+        model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing)
+
+        # Now patch lm_head and embed_tokens
+        if train_embed_tokens:
+            print("Unsloth: Casting embed_tokens to float32")
+            assert(hasattr(model.model.model.embed_tokens, "modules_to_save"))
+            model.model.model.embed_tokens.modules_to_save.default\
+                .to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
+            model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
+        pass
+
+        if train_lm_head:
+            print("Unsloth: Casting lm_head to float32")
+            assert(hasattr(model.model.lm_head, "modules_to_save"))
+            model.model.lm_head.modules_to_save.default\
+                .to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
+            model.model.lm_head.modules_to_save.default.requires_grad_(True)
+        pass
+
+        # Patch tokenizer to pad to the right
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            if hasattr(internal_model, "_saved_temp_tokenizer"):
+                internal_model._saved_temp_tokenizer.padding_side = "right"
+            pass
+            internal_model = internal_model.model
+        pass
+        if hasattr(internal_model, "_saved_temp_tokenizer"):
+            internal_model._saved_temp_tokenizer.padding_side = "right"
+        pass
+
+        # Clear deleted GPU items
+        for _ in range(3):
+            gc.collect()
+            torch.cuda.empty_cache()
+        pass
+
+        return model
+    pass
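+    # Hedged usage sketch:
+    #
+    #   model = FastLlamaModel.get_peft_model(
+    #       model,
+    #       r = 16, lora_alpha = 16, lora_dropout = 0, bias = "none",
+    #       target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+    #                         "gate_proj", "up_proj", "down_proj"],
+    #       use_gradient_checkpointing = "unsloth",
+    #   )
+    #
+    # Keeping lora_dropout = 0 and bias = "none" lets `patch_peft_model` below swap in the
+    # fast QKV / O / MLP LoRA autograd paths; "unsloth" gradient checkpointing also offloads
+    # embed_tokens / lm_head to disk when they are trained.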
+
+
+    @staticmethod
+    def patch_peft_model(
+        model,
+        use_gradient_checkpointing = True,
+    ):
+        if not isinstance(model, PeftModelForCausalLM):
+            raise TypeError(
+                "Unsloth: Your model needs to call `.get_peft_model` first!"
+            )
+        pass
+
+        # Get activation function
+        model_type = model.config.model_type
+
+        if   model_type == "llama":   apply_lora_mlp = apply_lora_mlp_swiglu
+        elif model_type == "mistral": apply_lora_mlp = apply_lora_mlp_swiglu
+        elif model_type == "qwen2":   apply_lora_mlp = apply_lora_mlp_swiglu
+        elif model_type == "gemma":   apply_lora_mlp = apply_lora_mlp_geglu_approx
+        elif model_type == "gemma2":  apply_lora_mlp = apply_lora_mlp_geglu_approx
+        elif model_type == "cohere":  apply_lora_mlp = apply_lora_mlp_swiglu
+        else:
+            raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!")
+        pass
+
+        model = prepare_model_for_kbit_training(
+            model,
+            use_gradient_checkpointing = use_gradient_checkpointing,
+            use_reentrant = True,
+        )
+
+        # Fix up config for transformers uploading PEFT
+        for active_adapter in model.peft_config.keys():
+            # Not necessary since we require transformers >= 4.37
+            if False:
+                name = model.peft_config[active_adapter].base_model_name_or_path
+                if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
+                    name = name[:len(name) - len("-bnb-4bit")]
+                    model.peft_config[active_adapter].base_model_name_or_path = name
+                pass
+            # Add revision to enable future fast inference paths
+            # [TODO] Bugs out! See https://github.com/unslothai/unsloth/issues/492
+            # model.peft_config[active_adapter].revision = f"unsloth"
+        pass
+
+        from transformers.trainer import Trainer 
+        if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
+            raise RuntimeError(
+                'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\
+                'enabling it will require much more work, so we have to prioritize. Please understand!\n'\
+                'We do have a separate beta version, which you can contact us about!\n'\
+                'Thank you for your understanding and we appreciate it immensely!'
+            )
+        pass
+
+        # Fix loftq issues
+        # loftq_config must not = None, but rather {}
+        all_configs = model.peft_config
+        for key, current_config in all_configs.items():
+            if hasattr(current_config, "loftq_config") and current_config.loftq_config is None:
+                new_args = current_config.__dict__
+                new_args["loftq_config"] = {}
+                current_config = current_config.__class__(**new_args)
+                all_configs[key] = current_config
+            pass
+        pass
+
+        # Do patching
+        n_mlp = 0
+        n_qkv = 0
+        n_o   = 0
+        import types
+
+        active_adapter = model.active_adapters[0] if \
+            hasattr(model, "active_adapters") else model.active_adapter
+
+        # Get dropout and bias
+        lora_dropout = model.peft_config[active_adapter].lora_dropout
+        bias         = model.peft_config[active_adapter].bias
+
+        # We also do not inplace edit QKV for Cohere!
+        from functools import partial
+        _apply_lora_mlp = \
+            partial(apply_lora_mlp, inplace = False) \
+            if model_type == "cohere" else \
+            apply_lora_mlp
+        pass
+
+        if lora_dropout == 0 and bias == "none":
+            for idx, layer in enumerate(model.model.model.layers):
+
+                # MLP patching
+                gate_proj = layer.mlp.gate_proj
+                up_proj   = layer.mlp.  up_proj
+                down_proj = layer.mlp.down_proj
+
+                if  hasattr(gate_proj, "lora_A") and \
+                    hasattr(  up_proj, "lora_A") and \
+                    hasattr(down_proj, "lora_A") and \
+                    (getattr(gate_proj, "base_layer", gate_proj).bias is None) and \
+                    (getattr(  up_proj, "base_layer",   up_proj).bias is None) and \
+                    (getattr(down_proj, "base_layer", down_proj).bias is None) and \
+                    (len(getattr(gate_proj, "lora_magnitude_vector", []) or []) == 0) and \
+                    (len(getattr(  up_proj, "lora_magnitude_vector", []) or []) == 0) and \
+                    (len(getattr(down_proj, "lora_magnitude_vector", []) or []) == 0):
+
+                    # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
+                    layer.mlp.forward = types.MethodType(_apply_lora_mlp, layer.mlp)
+                    n_mlp += 1
+                else:
+                    logger.warning_once(
+                        "Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters\n"\
+                        "are not enabled or a bias term (like in Qwen) is used."
+                    )
+                pass
+
+                # QKV attention patching
+                q_proj = layer.self_attn.q_proj
+                k_proj = layer.self_attn.k_proj
+                v_proj = layer.self_attn.v_proj
+                if  hasattr(q_proj, "lora_A") and \
+                    hasattr(k_proj, "lora_A") and \
+                    hasattr(v_proj, "lora_A") and \
+                    (getattr(q_proj, "base_layer", q_proj).bias is None) and \
+                    (getattr(k_proj, "base_layer", k_proj).bias is None) and \
+                    (getattr(v_proj, "base_layer", v_proj).bias is None) and \
+                    (len(getattr(q_proj, "lora_magnitude_vector", []) or []) == 0) and \
+                    (len(getattr(k_proj, "lora_magnitude_vector", []) or []) == 0) and \
+                    (len(getattr(v_proj, "lora_magnitude_vector", []) or []) == 0):
+
+                    layer.self_attn.apply_qkv = apply_lora_qkv
+                    n_qkv += 1
+                else:
+                    if model_type != "qwen2":
+                        logger.warning_once(
+                            "Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\n"\
+                            "are not enabled or a bias term (like in Qwen) is used."
+                        )
+                    pass
+                pass
+
+                # O attention patching
+                o_proj = layer.self_attn.o_proj
+                if hasattr(o_proj, "lora_A") and \
+                    (getattr(o_proj, "base_layer", o_proj).bias is None) and \
+                    (len(getattr(o_proj, "lora_magnitude_vector", []) or []) == 0):
+
+                    layer.self_attn.apply_o = apply_lora_o
+                    n_o += 1
+                else:
+                    logger.warning_once(
+                        "Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters\n"\
+                        "are not enabled or a bias term (like in Qwen) is used."
+                    )
+                pass
+            pass
+        pass
+
+        logger.warning_once(
+            f"Unsloth {__version__} patched {len(model.model.model.layers)} layers with "\
+            f"{n_qkv} QKV layers, {n_o} O layers and {n_mlp} MLP layers.",
+        )
+        patch_saving_functions(model)
+
+        # Patch cross entropy loss labels
+        # Fixes https://github.com/unslothai/unsloth/issues/10
+        max_seq_length = model.max_seq_length
+        extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0")
+        model.model.extra_ignored_labels = extra_ignored_labels
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            internal_model.max_seq_length = max_seq_length
+            internal_model = internal_model.model
+        pass
+        internal_model.max_seq_length = max_seq_length        
+
+        # Patch tokenizer to pad to the right
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            if hasattr(internal_model, "_saved_temp_tokenizer"):
+                internal_model._saved_temp_tokenizer.padding_side = "right"
+            pass
+            internal_model = internal_model.model
+        pass
+        if hasattr(internal_model, "_saved_temp_tokenizer"):
+            internal_model._saved_temp_tokenizer.padding_side = "right"
+        pass
+
+        # Clear deleted GPU items
+        for _ in range(3):
+            gc.collect()
+            torch.cuda.empty_cache()
+        pass
+        return model
+    pass
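+    # Hedged note: `get_peft_model` above calls this automatically; calling it directly is
+    # only needed for a PeftModelForCausalLM created outside Unsloth, and it requires that
+    # `from_pretrained` has already installed `_fast_inner_training_loop`.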
+
+
+    @staticmethod
+    def for_inference(model):
+        # if model.config.model_type == "qwen2":
+        #     FastLlamaModel.for_training(model)
+        #     return
+        # pass
+
+        internal_model = model
+        internal_model.gradient_checkpointing = False
+        internal_model.training = False
+
+        while hasattr(internal_model, "model"):
+            internal_model = internal_model.model
+            internal_model.gradient_checkpointing = False
+            internal_model.training = False
+        pass
+        if hasattr(internal_model, "training"):
+            internal_model.training = False
+        pass
+
+        # Also check if lm_head / embeddings are trained
+        internal_model = model
+        while not hasattr(internal_model, "lm_head"):
+            internal_model = internal_model.model
+        pass
+        lm_head = internal_model.lm_head.weight
+        device_type = lm_head.device.type
+        dtype = model.config.torch_dtype
+        
+        if type(dtype) is str:
+            if   dtype ==  "float16": dtype = torch.float16
+            elif dtype == "bfloat16": dtype = torch.bfloat16
+        pass
+
+        # Wrap model.generate
+        if model.generate.__name__ != "_fast_generate":
+            model._unwrapped_old_generate = model.generate
+            model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model)
+        pass
+        
+        # Patch tokenizer to pad to the left
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            if hasattr(internal_model, "_saved_temp_tokenizer"):
+                internal_model._saved_temp_tokenizer.padding_side = "left"
+            pass
+            internal_model = internal_model.model
+        pass
+        if hasattr(internal_model, "_saved_temp_tokenizer"):
+            internal_model._saved_temp_tokenizer.padding_side = "left"
+        pass
+
+        # Also disable training for embeddings for NEFTune
+        if hasattr(model, "get_input_embeddings"):
+            embeddings = model.get_input_embeddings()
+            if hasattr(embeddings, "training"): embeddings.training = False
+        pass
+        if hasattr(model, "get_output_embeddings"):
+            embeddings = model.get_output_embeddings()
+            if hasattr(embeddings, "training"): embeddings.training = False
+        pass
+
+        return model
+    pass
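+    # Hedged usage sketch: toggle between generation and training modes, e.g.
+    #
+    #   FastLlamaModel.for_inference(model)           # wraps model.generate, pads left
+    #   out = model.generate(**inputs, max_new_tokens = 64)
+    #   FastLlamaModel.for_training(model)            # restores generate, pads right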
+
+
+    @staticmethod
+    def for_training(model, use_gradient_checkpointing = True):
+        internal_model = model
+        internal_model.gradient_checkpointing = use_gradient_checkpointing
+        internal_model.training = True
+
+        # Delete all fast inference loras
+        for param in model.parameters():
+            if hasattr(param, "_fast_lora"):
+                del param._fast_lora
+        pass
+
+        while hasattr(internal_model, "model"):
+            internal_model = internal_model.model
+            internal_model.gradient_checkpointing = use_gradient_checkpointing
+            internal_model.training = True
+        pass
+        if hasattr(internal_model, "training"):
+            internal_model.training = True
+        pass
+
+        # Also revert model.generate
+        if hasattr(model, "_unwrapped_old_generate"):
+            model.generate = model._unwrapped_old_generate
+            del model._unwrapped_old_generate
+        pass
+
+        # Patch tokenizer to pad to the right
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            if hasattr(internal_model, "_saved_temp_tokenizer"):
+                internal_model._saved_temp_tokenizer.padding_side = "right"
+            pass
+            internal_model = internal_model.model
+        pass
+        if hasattr(internal_model, "_saved_temp_tokenizer"):
+            internal_model._saved_temp_tokenizer.padding_side = "right"
+        pass
+
+        # Also re-enable training for embeddings for NEFTune
+        if hasattr(model, "get_input_embeddings"):
+            embeddings = model.get_input_embeddings()
+            if hasattr(embeddings, "training"): embeddings.training = True
+        pass
+        if hasattr(model, "get_output_embeddings"):
+            embeddings = model.get_output_embeddings()
+            if hasattr(embeddings, "training"): embeddings.training = True
+        pass
+
+        return model
+    pass
+pass
+