File size: 1,385 Bytes
e96c59c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import json

import torch
import transformers
from tqdm import tqdm
def write_json(file_path, data):
    """Serialize *data* to *file_path* as pretty-printed UTF-8 JSON.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``),
    with a 4-space indent for readability.
    """
    with open(file_path, mode='w', encoding='utf-8') as fh:
        json.dump(data, fh, indent=4, ensure_ascii=False)
# Local checkpoint of Llama-3.3-70B-Instruct to run text generation with.
model_id = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/ICCV_2025/LLama3_70B/Llama-3.3-70B-Instruct"

# Build a text-generation pipeline; device_map="auto" shards the model
# across available GPUs, bfloat16 halves the memory footprint.
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

# Fixed chat prompt sent to the model on every iteration below.
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

# Input records to iterate over in batches.
json_path = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/ICCV_2025/LLama3_70B/llama3/merged_data.json"
with open(json_path, encoding="utf-8") as f:
    data = json.load(f)

ans = []
begin, end, batch_size = 0, len(data), 4
cnt = 0
# FIX: loop previously referenced undefined `max_batch_size` (NameError);
# the stride variable defined above is `batch_size`.
for batch_idx in tqdm(range(begin, end, batch_size)):
    up = min(batch_idx + batch_size, end)  # clamp final partial batch
    batch = data[batch_idx:up]
    print(f"batch {batch_idx} to {up}")
    # NOTE(review): the original appended the loop *index* here and
    # `text_batch` was never consumed — the pipeline runs on the fixed
    # `messages` prompt each iteration. Collecting the items instead of
    # indices; TODO confirm whether `batch` content should be fed to the
    # pipeline rather than `messages`.
    text_batch = list(batch)
    outputs = pipeline(messages, max_new_tokens=2048)[0]["generated_text"]
    ans.append(outputs)
    cnt += 1
    if cnt % 10 == 0:
        print(f"batch {cnt} done")
# FIX: arguments were swapped (`write_json(ans, "ans.json")`), which would
# pass a list as the file path; the signature is write_json(file_path, data).
write_json("ans.json", ans)
|