QLoRA+百万数据对baichuan-7b模型进行高效指令微调

更多详情请查看GitHub项目：[Firefly(流萤): 中文对话式大语言模型(全量微调+QLoRA)](https://github.com/yangjianxin1/Firefly)

单轮对话脚本：

```python
# Single-turn chat demo: every user message is answered on its own,
# no conversation history is kept between turns.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = 'YeungNLP/firefly-baichuan-7b-qlora-sft-merge'
max_new_tokens = 500
top_p = 0.9
temperature = 0.35
repetition_penalty = 1.0
# Prompt template: the user's message is wrapped in <s> ... </s>.
input_pattern = '<s>{}</s>'

# device_map='auto' lets accelerate place the weights (possibly across several
# devices), so the model must not be moved again with .to(...) afterwards.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map='auto',
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

while True:
    try:
        text = input('User:')
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C ends the session instead of crashing with a traceback.
        break

    prompt = input_pattern.format(text)
    # Send the inputs to model.device, the device of the model's first parameters.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the new tokens rather
    # than string-replacing the prompt out of the full decode (the decoded prompt
    # need not match `prompt` exactly, and replace() would also eat repeats of it).
    response_ids = outputs[0, input_ids.size(1):]
    output = tokenizer.decode(response_ids, skip_special_tokens=True).strip()
    print("Firefly:{}".format(output))
```

多轮对话脚本：

```python
# Multi-turn chat demo: the whole conversation is kept as token ids and the
# most recent `history_max_len` tokens are fed to the model on every turn.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = 'YeungNLP/firefly-baichuan-7b-qlora-sft-merge'
max_new_tokens = 500
top_p = 0.9
temperature = 0.35
repetition_penalty = 1.0
# Maximum number of history tokens given to the model per turn.
history_max_len = 1000

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# device_map='auto' lets accelerate place the weights (possibly across several
# devices), so the model must not be moved again with .to(...) afterwards.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map='auto',
)
model.eval()

# Conversation history as a (1, seq_len) CPU tensor, seeded with the tokens of
# '<s>'; each user turn and each model reply is closed by </s>.
history_token_ids = tokenizer('<s>', return_tensors="pt").input_ids
eos_ids = torch.tensor([[tokenizer.eos_token_id]], dtype=history_token_ids.dtype)

while True:
    try:
        user_input = input('User:')
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C ends the session instead of crashing with a traceback.
        break

    user_input_ids = tokenizer('{}</s>'.format(user_input), return_tensors="pt").input_ids
    history_token_ids = torch.concat((history_token_ids, user_input_ids), dim=1)
    # Truncation is by token count, so the oldest kept turn may be cut mid-way.
    model_input_ids = history_token_ids[:, -history_max_len:].to(model.device)

    outputs = model.generate(
        input_ids=model_input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    response_ids = outputs[:, model_input_ids.size(1):].cpu()
    response = tokenizer.decode(response_ids[0], skip_special_tokens=True).strip()
    print("Firefly:" + response)

    # If generation stopped at max_new_tokens there is no trailing </s>; close the
    # turn explicitly so the next user message does not run into this reply.
    if response_ids.size(1) == 0 or response_ids[0, -1].item() != tokenizer.eos_token_id:
        response_ids = torch.concat((response_ids, eos_ids), dim=1)
    # Only the last history_max_len tokens are ever fed to the model, so older
    # tokens can be dropped without changing what the model sees.
    history_token_ids = torch.concat((history_token_ids, response_ids), dim=1)[:, -history_max_len:]
```