This model was fine-tuned from baichuan-13b-base with the [Firefly](https://github.com/yangjianxin1/Firefly) project. The training data consists of roughly one million multi-turn dialogues: the moss data released by the project plus 20k school math samples.
For more details, see the project: [Firefly](https://github.com/yangjianxin1/Firefly)
Technical write-up: [Firefly增强Baichuan-13B的多轮对话能力](https://mp.weixin.qq.com/s/djO8Tg3emmy6wzw_rTUlcw)
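The multi-turn capability comes from packing every dialogue into a single token sequence of the form `<s>input1</s>target1</s>input2</s>target2</s>...`, which is also the format the inference scripts below rebuild at chat time. Below is a minimal sketch of that packing, assuming a tokenizer with `bos_token_id` and `eos_token_id` set; `build_training_ids` is a hypothetical helper for illustration, not the project's actual preprocessing code:

```python
from transformers import AutoTokenizer


def build_training_ids(tokenizer, dialogue):
    # Illustrative sketch of the firefly multi-turn format:
    # <s>input1</s>target1</s>input2</s>target2</s>...
    # `dialogue` is a list of (user_input, assistant_target) string pairs.
    token_ids = [tokenizer.bos_token_id]  # the sequence opens with <s>
    for user_input, target in dialogue:
        token_ids += tokenizer(user_input, add_special_tokens=False).input_ids
        token_ids.append(tokenizer.eos_token_id)  # </s> closes the user turn
        token_ids += tokenizer(target, add_special_tokens=False).input_ids
        token_ids.append(tokenizer.eos_token_id)  # </s> closes the assistant turn
    return token_ids


tokenizer = AutoTokenizer.from_pretrained('YeungNLP/firefly-baichuan-13b', trust_remote_code=True)
print(build_training_ids(tokenizer, [('你好', '你好,很高兴为你服务')]))
```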
Training loss:
<img src="https://huggingface.co/YeungNLP/firefly-baichuan-13b/resolve/main/firefly-baichuan-13b-loss.jpg" width="450">

C-Eval leaderboard:
| Model | C-Eval | STEM | Social Science | Humanities | Other |
|----------------------------------|--------|-------|----------------|------------|-------|
| Baichuan-13B-Chat (official) | 52.05 | 42.23 | 65.27 | 58.61 | 51.32 |
| **firefly-baichuan-13b** | 51.36 | 44.24 | 61.65 | 54.63 | 51.68 |
| chatglm2-6b (official) | 50.45 | 41.91 | 60.73 | 59.24 | 47.82 |
| **firefly-chatglm2-6b** | 49.13 | 43.6 | 58.83 | 54.48 | 45.03 |
| openbuddy-llama2-13b-v11.1-bf16 | 43.36 | 39.79 | 50.28 | 44.78 | 42.13 |
| chinese-alpaca-2-13b (HIT) | 41.86 | 36.52 | 49.7 | 47.97 | 38.33 |
| openbuddy-llama2-13b-v8.1-fp16 | 41.62 | 38.82 | 44.66 | 40.28 | 45.32 |
| chinese-alpaca-2-7b (HIT) | 41.48 | 35.01 | 50.08 | 43.02 | 43.87 |
| belle-llama2-13B-chat-0.4M | 41.11 | 40.04 | 44.71 | 42.09 | 38.82 |
| ziya-llama-13b | 39.1 | - | - | - | - |
| llama-2-13b-chat (official) | 36.38 | 33.68 | 46.38 | 34.47 | 34.1 |
| llama-2-7b-chat (official) | 35.86 | 32.85 | 40.04 | 37.37 | 36.01 |
| flagalpha/Llama2-Chinese-7b-Chat | 34.54 | 35.21 | 37.9 | 33.11 | 31.7 |
| yayi-13b-llama2 | 34.15 | 36.48 | 30.64 | 32.67 | 34.6 |
| yayi-7b-llama2 | 30.18 | 25.88 | 38.23 | 34.56 | 26.31 |
| linly-llama2-7b | 28.35 | 26.06 | 33.47 | 29.71 | 26.53 |
| linly-llama2-13b | 27.86 | 27.67 | 26.95 | 27.93 | 28.95 |
Single-turn dialogue:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
"""
单轮对话,不具有对话历史的记忆功能
"""
def main():
model_name = 'YeungNLP/firefly-baichuan-13b'
max_new_tokens = 500
top_p = 0.9
temperature = 0.35
repetition_penalty = 1.0
device = 'cuda'
model = AutoModelForCausalLM.from_pretrained(
model_name,
trust_remote_code=True,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
device_map='auto'
).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
# llama不支持fast
use_fast=False if model.config.model_type == 'llama' else True
)
# QWenTokenizer比较特殊,pad_token_id、bos_token_id、eos_token_id均为None。eod_id对应的token为<|endoftext|>
if tokenizer.__class__.__name__ == 'QWenTokenizer':
tokenizer.pad_token_id = tokenizer.eod_id
tokenizer.bos_token_id = tokenizer.eod_id
tokenizer.eos_token_id = tokenizer.eod_id
text = input('User:')
while True:
text = text.strip()
# chatglm使用官方的数据组织格式
if model.config.model_type == 'chatglm':
text = '[Round 1]\n\n问:{}\n\n答:'.format(text)
input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
# 为了兼容qwen-7b,因为其对eos_token进行tokenize,无法得到对应的eos_token_id
else:
input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(device)
eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(device)
input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)
with torch.no_grad():
outputs = model.generate(
input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
eos_token_id=tokenizer.eos_token_id
)
outputs = outputs.tolist()[0][len(input_ids[0]):]
response = tokenizer.decode(outputs)
response = response.strip().replace(tokenizer.eos_token, "").strip()
print("Firefly:{}".format(response))
text = input('User:')
if __name__ == '__main__':
main()
```
Multi-turn dialogue:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
def main():
    model_name = 'YeungNLP/firefly-baichuan-13b'
    device = 'cuda'
    max_new_tokens = 500    # maximum number of tokens generated per turn
    history_max_len = 1000  # maximum number of history tokens the model keeps
    top_p = 0.9
    temperature = 0.35
    repetition_penalty = 1.0

    # load the model; device_map='auto' already places the weights, so no extra .to(device) call is needed
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map='auto'
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        # llama does not support the fast tokenizer
        use_fast=False if model.config.model_type == 'llama' else True
    )
    # QWenTokenizer is special: its pad_token_id, bos_token_id and eos_token_id are all None,
    # and eod_id corresponds to the <|endoftext|> token
    if tokenizer.__class__.__name__ == 'QWenTokenizer':
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    # token ids of the whole conversation so far
    if model.config.model_type != 'chatglm':
        history_token_ids = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long)
    else:
        history_token_ids = torch.tensor([[]], dtype=torch.long)

    # start the conversation
    utterance_id = 0  # current turn number, needed for the chatglm prompt format
    user_input = input('User:')
    while True:
        utterance_id += 1
        # chatglm uses its official prompt format
        if model.config.model_type == 'chatglm':
            user_input = '[Round {}]\n\n问:{}\n\n答:'.format(utterance_id, user_input)
            user_input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=False).input_ids
        # otherwise use the firefly format and append eos manually; this keeps compatibility with
        # qwen-7b, whose tokenizer tokenizes eos_token as plain text and cannot return its eos_token_id
        else:
            input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=False).input_ids
            eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long)
            user_input_ids = torch.concat([input_ids, eos_token_id], dim=1)
        history_token_ids = torch.concat((history_token_ids, user_input_ids), dim=1)
        # keep only the most recent history_max_len tokens as the model input
        model_input_ids = history_token_ids[:, -history_max_len:].to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=model_input_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p,
                temperature=temperature, repetition_penalty=repetition_penalty, eos_token_id=tokenizer.eos_token_id
            )
        model_input_ids_len = model_input_ids.size(1)
        response_ids = outputs[:, model_input_ids_len:]
        # append the response to the history so the next turn can see it
        history_token_ids = torch.concat((history_token_ids, response_ids.cpu()), dim=1)
        response = tokenizer.batch_decode(response_ids)
        print("Firefly:" + response[0].strip().replace(tokenizer.eos_token, ""))
        user_input = input('User:')


if __name__ == '__main__':
    main()
```