Wendy-Fly committed
Commit fe4c470 · verified · 1 Parent(s): 45050fd

Upload example_text_completion.py with huggingface_hub

Files changed (1)
  1. example_text_completion.py +99 -0
example_text_completion.py ADDED
@@ -0,0 +1,99 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
+
+ from typing import List, Optional
+
+ import json
+
+ import fire
+ from tqdm import tqdm
+
+ from llama import Llama
+
+
+ def read_json(file_path):
+     with open(file_path, 'r', encoding='utf-8') as file:
+         data = json.load(file)
+     return data
+
+
+ def write_json(file_path, data):
+     with open(file_path, 'w', encoding='utf-8') as file:
+         json.dump(data, file, ensure_ascii=False, indent=4)
+
+
+ def main(
+     ckpt_dir: str,
+     tokenizer_path: str,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     max_seq_len: int = 128,
+     max_gen_len: int = 64,
+     max_batch_size: int = 4,
+     json_path: Optional[str] = None,
+ ):
+     """
+     Examples to run with the pre-trained models (no fine-tuning). Prompts are
+     usually in the form of an incomplete text prefix that the model can then try to complete.
+
+     The context window of Llama 3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192.
+     `max_gen_len` is needed because pre-trained models usually do not stop completions naturally.
+     """
+     generator = Llama.build(
+         ckpt_dir=ckpt_dir,
+         tokenizer_path=tokenizer_path,
+         max_seq_len=max_seq_len,
+         max_batch_size=max_batch_size,
+     )
+
+     # The input JSON is expected to hold a list of prompt strings.
+     assert json_path is not None, "json_path must point to a JSON list of prompts"
+     data = read_json(json_path)
+
+     ans = []
+     cnt = 0
+     begin, end = 0, len(data)
+     for batch_idx in tqdm(range(begin, end, max_batch_size)):
+         up = min(batch_idx + max_batch_size, end)
+         batch = data[batch_idx:up]
+         print(f"batch {batch_idx} to {up}")
+
+         # Each batch entry is assumed to be a raw prompt string.
+         text_batch = list(batch)
+         res = generator.text_completion(
+             text_batch,
+             max_gen_len=max_gen_len,
+             temperature=temperature,
+             top_p=top_p,
+         )
+         ans.append(res)
+         cnt += 1
+         if cnt % 10 == 0:
+             print(f"batch {cnt} done")
+
+     write_json("ans.json", ans)
+
+     # prompts: List[str] = [
+     #     # For these prompts, the expected answer is the natural continuation of the prompt
+     #     "I believe the meaning of life is",
+     #     "Simply put, the theory of relativity states that ",
+     #     """A brief message congratulating the team on the launch:
+
+     #     Hi everyone,
+
+     #     I just """,
+     #     # Few shot prompt (providing a few examples before asking model to complete more);
+     #     """Translate English to French:
+
+     #     sea otter => loutre de mer
+     #     peppermint => menthe poivrée
+     #     plush girafe => girafe peluche
+     #     cheese =>""",
+     # ]
+     # results = generator.text_completion(
+     #     prompts,
+     #     max_gen_len=max_gen_len,
+     #     temperature=temperature,
+     #     top_p=top_p,
+     # )
+     # for prompt, result in zip(prompts, results):
+     #     print(prompt)
+     #     print(f"> {result['generation']}")
+     #     print("\n==================================\n")
+
+
+ if __name__ == "__main__":
+     fire.Fire(main)
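
For reference, here is a minimal sketch of an end-to-end run. It writes a prompts file in the list-of-strings layout the loop above assumes, then launches the script with torchrun as in the llama3 repository's examples; prompts.json and the Meta-Llama-3-8B paths are placeholder names, not part of this commit.

# prepare_prompts.py -- hypothetical helper, not part of the uploaded file
import json

prompts = [
    "I believe the meaning of life is",
    "Simply put, the theory of relativity states that ",
]

# The script's json_path loader expects exactly this shape: a JSON list of strings.
with open("prompts.json", "w", encoding="utf-8") as f:
    json.dump(prompts, f, ensure_ascii=False, indent=4)

# Then launch (checkpoint paths are placeholders for a local Llama 3 download):
#   torchrun --nproc_per_node 1 example_text_completion.py \
#       --ckpt_dir Meta-Llama-3-8B/ \
#       --tokenizer_path Meta-Llama-3-8B/tokenizer.model \
#       --max_seq_len 128 --max_batch_size 4 \
#       --json_path prompts.json
# Results are saved to ans.json, one list of completions per batch.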