from typing import List, Optional

import json

import fire
from tqdm import tqdm

from llama import Llama


def read_json(file_path):
    """Load JSON data from a UTF-8 encoded file."""
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data


def write_json(file_path, data):
    """Write JSON data to a UTF-8 encoded file, keeping non-ASCII characters readable."""
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=4)
|
def main( |
|
ckpt_dir: str, |
|
tokenizer_path: str, |
|
temperature: float = 0.6, |
|
top_p: float = 0.9, |
|
max_seq_len: int = 128, |
|
max_gen_len: int = 64, |
|
max_batch_size: int = 4, |
|
json_path: str = None, |
|
): |
|
""" |
|
Examples to run with the pre-trained models (no fine-tuning). Prompts are |
|
usually in the form of an incomplete text prefix that the model can then try to complete. |
|
|
|
The context window of llama3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192. |
|
`max_gen_len` is needed because pre-trained models usually do not stop completions naturally. |
|
""" |
|
generator = Llama.build( |
|
ckpt_dir=ckpt_dir, |
|
tokenizer_path=tokenizer_path, |
|
max_seq_len=max_seq_len, |
|
max_batch_size=max_batch_size, |
|
) |
|
with open(json_path) as f: |
|
data = json.load(f) |
|
|
|
ans = [] |
|
begin, end,batch_size = 0,len(data),max_batch_size |
    for batch_idx in tqdm(range(begin, end, max_batch_size)):
        up = min(batch_idx + max_batch_size, end)
        batch = data[batch_idx:up]
        print(f"batch {batch_idx} to {up}")

        # Each entry in the batch is assumed to be a prompt string; collect the prompts for the model.
        text_batch = [item for item in batch]
        res = generator.text_completion(
            text_batch,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
        )
        ans.append(res)
        cnt += 1
        if cnt % 10 == 0:
            print(f"batch {cnt} done")

    write_json("ans.json", ans)
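

# A typical way to launch this script (a sketch; the checkpoint, tokenizer, and
# prompt-file paths below are placeholders for your local setup). Llama.build
# relies on torch.distributed, so the script is normally started with torchrun:
#
#   torchrun --nproc_per_node 1 this_script.py \
#       --ckpt_dir Meta-Llama-3-8B/ \
#       --tokenizer_path Meta-Llama-3-8B/tokenizer.model \
#       --max_seq_len 128 --max_batch_size 4 \
#       --json_path prompts.json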
if __name__ == "__main__":
    fire.Fire(main)