from transformers.integrations import TransformersPlugin, replace_target_class

from .llama_xformers_attention import LlamaXFormersAttention


class LlamaXFormersPlugin(TransformersPlugin):
    """Swaps LlamaAttention modules for xFormers-backed equivalents."""

    def __init__(self, config):
        pass  # stateless: the relevant config is read from the model itself

    def process_model_pre_init(self, model):
        # Replace every LlamaAttention module with LlamaXFormersAttention,
        # constructing each replacement from the model's own config.
        replace_target_class(model, LlamaXFormersAttention, "LlamaAttention", init_kwargs={"config": model.config})
        return model

    def process_model_post_init(self, model):
        return model  # no post-init processing is needed
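

# --- Usage sketch (illustrative only) ---
# A minimal, hedged example of driving the plugin by hand. It assumes the
# TransformersPlugin machinery would normally invoke these hooks itself during
# model construction; here the hooks are applied to an already-built model
# purely for illustration. LlamaConfig and LlamaForCausalLM are standard
# transformers classes; the sizes below are arbitrary toy values.
if __name__ == "__main__":
    from transformers import LlamaConfig, LlamaForCausalLM

    config = LlamaConfig(num_hidden_layers=2, hidden_size=256, num_attention_heads=4)
    model = LlamaForCausalLM(config)

    plugin = LlamaXFormersPlugin(config)
    model = plugin.process_model_pre_init(model)   # LlamaAttention -> LlamaXFormersAttention
    model = plugin.process_model_post_init(model)  # no-op hook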