Update llama_xformers_attention.py
llama_xformers_attention.py (CHANGED)
@@ -3,7 +3,11 @@ import torch.nn as nn
 
 from typing import Optional, Tuple
 
-from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
+
+from xformers.ops.fmha import (
+    memory_efficient_attention,
+)
 
 class LlamaXFormersAttention(LlamaAttention):
     def forward(
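The diff widens the transformers import to also pull in apply_rotary_pos_emb and adds xformers' memory_efficient_attention, presumably so the overridden forward can apply rotary position embeddings itself and then route the attention computation through xformers instead of the stock implementation. Below is a minimal sketch of what such a forward typically looks like. It is an assumption, not part of this diff: the attribute names (q_proj, k_proj, v_proj, o_proj, rotary_emb, num_heads, head_dim, hidden_size), the five-argument apply_rotary_pos_emb, and the tuple-style KV cache follow older transformers releases (roughly 4.28-4.30), and LowerTriangularMask is pulled in for causal masking even though this hunk only imports memory_efficient_attention.

from typing import Optional, Tuple

import torch
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import LowerTriangularMask  # assumption: not imported in this diff


class LlamaXFormersAttention(LlamaAttention):
    # Hypothetical sketch of the overridden forward; padding attention_mask
    # and output_attentions are not supported here.
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        # Project to per-head query/key/value, shaped (B, H, S, D) as RoPE expects.
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        # Rotary position embeddings, applied before the new keys are cached.
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None

        # xformers wants (B, S, H, D). A lower-triangular bias gives causal
        # masking during prefill (assumed to run without a cache); single-token
        # decode steps may attend to every cached key, so no bias is needed.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        attn_bias = LowerTriangularMask() if q_len > 1 else None
        attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=attn_bias)

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        # The fused kernel never materializes the attention matrix, so the
        # attention-weights slot of the return tuple is always None.
        return attn_output, None, past_key_value

Note the two layouts in play: apply_rotary_pos_emb operates on (batch, heads, seq, head_dim) tensors, while memory_efficient_attention takes (batch, seq, heads, head_dim), hence the transposes after RoPE. xformers' default softmax scale of 1/sqrt(head_dim) already matches Llama's attention scaling, so no explicit scale argument is needed.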