|
import json

import torch
import transformers
from tqdm import tqdm
|
|
|
|
|
def write_json(file_path, data):
    """Serialize *data* to *file_path* as UTF-8 JSON (4-space indent, non-ASCII preserved)."""
    with open(file_path, 'w', encoding='utf-8') as out_file:
        json.dump(data, out_file, ensure_ascii=False, indent=4)
|
# Filesystem path to a local Llama-3.3-70B-Instruct checkpoint (loaded from
# disk, not downloaded from the Hub).
model_id = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/ICCV_2025/LLama3_70B/Llama-3.3-70B-Instruct"



# Text-generation pipeline: bfloat16 weights to halve memory, with
# device_map="auto" sharding the 70B model across all visible devices
# (requires the `accelerate` package).
pipeline = transformers.pipeline(

    "text-generation",

    model=model_id,

    model_kwargs={"torch_dtype": torch.bfloat16},

    device_map="auto",

)
|
|
|
# Chat prompt in the Transformers chat format: the system turn fixes the
# persona, the user turn is the actual query.
messages = [

    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},

    {"role": "user", "content": "Who are you?"},

]



# Input dataset path — presumably a JSON list of records; verify the file's
# top-level structure before relying on slicing downstream.
json_path = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/ICCV_2025/LLama3_70B/llama3/merged_data.json"
|
|
|
# Load the dataset and run the pipeline over it in fixed-size batches,
# appending each generation to `ans` and periodically checkpointing to disk.
with open(json_path, encoding='utf-8') as f:

    data = json.load(f)

ans = []

begin, end, batch_size = 0, len(data), 4

cnt = 0

# FIX: the loop previously used an undefined name `max_batch_size`
# (NameError) — the variable defined above is `batch_size`.
for batch_idx in tqdm(range(begin, end, batch_size)):

    up = min(batch_idx + batch_size, end)

    batch = data[batch_idx:up]

    print(f"batch {batch_idx} to {up}")

    text_batch = []

    for idx, i in enumerate(batch):

        # NOTE(review): this collects the loop *index*, and `text_batch` is
        # never consumed below — presumably the per-item text was meant to be
        # gathered and fed to the pipeline. Kept as-is to preserve behavior;
        # confirm the intent.
        text_batch.append(idx)

    # NOTE(review): the pipeline runs on the fixed `messages` prompt and
    # ignores `batch` entirely, so every iteration issues the same query.
    outputs = pipeline(messages, max_new_tokens=2048,)[0]["generated_text"]

    ans.append(outputs)

    cnt += 1

    if cnt % 10 == 0:

        print(f"batch {cnt} done")

        # FIX: arguments were swapped (`write_json(ans, "ans.json")`), which
        # would attempt open(list, ...) — the signature is (file_path, data).
        write_json("ans.json", ans)

# FIX: flush the tail — without this, runs whose batch count is not a
# multiple of 10 would lose everything after the last checkpoint.
write_json("ans.json", ans)
|
|