"""Batch-inference script: run a Llama-3.3-70B chat pipeline over a JSON
dataset and dump the generated outputs to ``ans.json``."""

import json

import torch
import transformers
from tqdm import tqdm


def write_json(file_path, data):
    """Write *data* to *file_path* as pretty-printed, non-ASCII-escaped JSON.

    Args:
        file_path: Destination path for the JSON file.
        data: Any JSON-serializable object.
    """
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=4)


model_id = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/ICCV_2025/LLama3_70B/Llama-3.3-70B-Instruct"

# bfloat16 + device_map="auto" shards the 70B model across available GPUs.
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

json_path = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/ICCV_2025/LLama3_70B/llama3/merged_data.json"
with open(json_path, encoding='utf-8') as f:
    data = json.load(f)

ans = []
begin, end, batch_size = 0, len(data), 4
cnt = 0
# Fix: the original loop referenced an undefined `max_batch_size`;
# the variable actually defined above is `batch_size`.
for batch_idx in tqdm(range(begin, end, batch_size)):
    up = min(batch_idx + batch_size, end)
    batch = data[batch_idx:up]
    print(f"batch {batch_idx} to {up}")
    # NOTE(review): `batch` is sliced but its contents are never passed to the
    # pipeline — every iteration generates from the same fixed `messages`
    # prompt, so `ans` contains one near-identical output per batch. The
    # original also built a `text_batch` list of *indices* that was never
    # used (removed here as dead code). Presumably the intent was to build a
    # per-item prompt from each record in `batch` — confirm the dataset
    # schema and feed it into the pipeline call below.
    outputs = pipeline(messages, max_new_tokens=2048,)[0]["generated_text"]
    ans.append(outputs)
    cnt += 1
    if cnt % 10 == 0:
        print(f"batch {cnt} done")

# Fix: arguments were swapped in the original call
# (`write_json(ans, "ans.json")`), which would try to open the result
# list as a file path. Signature is write_json(file_path, data).
write_json("ans.json", ans)