Aman committed
Commit 0fcc8ab · verified · 1 Parent(s): bd9991a

Update README.md

Files changed (1)
  1. README.md +43 -16
README.md CHANGED
@@ -11,27 +11,54 @@ This model is a 7B Chinese version of [Self-RAG](https://huggingface.co/selfrag/
 It is trained on Baichuan2-7B-Chat with a sample of [BELLE](https://github.com/LianjiaTech/BELLE) SFT data, accompanied by interleaved passages from zhwiki. The reflection tokens are aligned with the original version (in English), so the usage is the same. Hope you enjoy it.
 
 ### Usage
+I found some output errors when adopting vLLM to accelerate the generation process, and I am not sure whether they are due to precision issues.
+They may also stem from the vLLM implementation, so I use the original `generate` method of transformers instead.
 ```
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from vllm import LLM, SamplingParams
+import os, torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained(YOUR_TOKENIZER_PATH)
+model = AutoModelForCausalLM.from_pretrained(
+    YOUR_MODEL_PATH,
+    torch_dtype=torch.bfloat16,
+    device_map="cuda",
+)
+
+### set up your retriever if necessary
+retriever = setup_retriever(YOUR_RETRIEVER_PATH)
 
-model = LLM(YOUR_MODEL_PATH, dtype="half")
-sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)
 
 def format_prompt(input, paragraph=None):
-    prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
-    if paragraph is not None:
-        prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
-    return prompt
-
-query_1 = "你好呀"
-query_2 = "故宫三大殿是哪些?"
-queries = [query_1, query_2]
-
-preds = model.generate([format_prompt(query) for query in queries], sampling_params)
-for pred in preds:
-    print("Model prediction: {0}".format(pred.outputs[0].text))
+    prompt = "### Instruction:\n{0}\n\n### Response:".format(input)
+    if paragraph is not None:
+        prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
+    return prompt
+
+
+while True:
+    query = input("[Human]: ")
+    prompt = format_prompt(query)
+    sequences = model.generate(
+        **tokenizer(prompt, return_tensors='pt').to(model.device),
+        do_sample=False,
+        num_beams=5,
+        # top_k=10,
+        # top_p=0.8,
+        temperature=0.9,
+        num_return_sequences=1,
+        eos_token_id=tokenizer.eos_token_id,
+        max_new_tokens=1024,
+        min_new_tokens=1,
+        repetition_penalty=1.5,
+    )
+    for seq in sequences:
+        print(f"[Model]: {tokenizer.decode(seq, skip_special_tokens=False)}")
+        print("-"*50)
+    print("="*50)
+
+# query_1 = "你好呀"
 # Model prediction: [No Retrieval] 你好!有什么我可以帮你解答的问题吗? [Utility:5] </s>
+# query_2 = "故宫三大殿是哪些?"
 # Model prediction: [Retrieval] <paragraph> ... (this query requires factual grounding, call a retriever) </paragraph> [Relevant] 太和殿、中和殿、保和殿 [Utility:5] </s>
 ```
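The updated snippet builds a `retriever` (via the placeholder `setup_retriever`) but never calls it: when the model answers with `[Retrieval]`, the caller is expected to fetch a passage and re-prompt through `format_prompt(query, paragraph)`. A minimal sketch of that round-trip could look like the following; it reuses `model`, `tokenizer`, `retriever`, and `format_prompt` from the snippet above and assumes a hypothetical `retriever.search(query)` method that returns a list of passage strings.

```
# Sketch only: `model`, `tokenizer`, `retriever`, and `format_prompt` are the objects
# defined in the README snippet above; `retriever.search` is an assumed interface.

def generate_text(prompt, max_new_tokens=1024):
    # Same decoding settings as the README loop (beam search, no sampling).
    sequences = model.generate(
        **tokenizer(prompt, return_tensors="pt").to(model.device),
        do_sample=False,
        num_beams=5,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        repetition_penalty=1.5,
    )
    return tokenizer.decode(sequences[0], skip_special_tokens=False)

def answer(query):
    # First pass: the model itself decides whether retrieval is needed.
    first = generate_text(format_prompt(query))
    if "[Retrieval]" not in first:
        return first
    # Second pass: fetch a passage and re-prompt with it wrapped in <paragraph> tags,
    # which is what format_prompt(query, paragraph) produces.
    paragraph = retriever.search(query)[0]  # assumed retriever API
    return generate_text(format_prompt(query, paragraph))

# Example: print(answer("故宫三大殿是哪些?"))
```

Because `skip_special_tokens=False`, reflection tokens such as `[No Retrieval]`, `[Relevant]`, and `[Utility:5]` remain in the decoded text; strip them before showing the answer to end users.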