Wendy-Fly committed on
Commit
e96c59c
·
verified ·
1 Parent(s): fe4c470

Upload infer_llama.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. infer_llama.py +41 -0
infer_llama.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json

import torch
import transformers
3
+
4
+
5
def write_json(file_path, data):
    """Serialize ``data`` to ``file_path`` as pretty-printed UTF-8 JSON.

    Args:
        file_path: Destination path; the file is created or overwritten.
        data: Any object that ``json.dump`` can serialize.

    Raises:
        OSError: If the file cannot be opened for writing.
        TypeError: If ``data`` contains a value that is not JSON-serializable.
    """
    with open(file_path, 'w', encoding='utf-8') as file:
        # ensure_ascii=False writes non-ASCII characters verbatim rather than
        # as \uXXXX escapes; indent=4 keeps the output human-readable.
        json.dump(data, file, ensure_ascii=False, indent=4)
8
# Local checkpoint directory of the model handed to the pipeline.
model_id = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/ICCV_2025/LLama3_70B/Llama-3.3-70B-Instruct"

# Text-generation pipeline; bfloat16 is requested through model_kwargs.
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

# Fixed chat prompt that is sent to the model on every iteration below.
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

json_path = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/ICCV_2025/LLama3_70B/llama3/merged_data.json"

# Read the input explicitly as UTF-8 so it matches how write_json writes.
with open(json_path, encoding="utf-8") as f:
    data = json.load(f)

# tqdm was used without ever being imported. Import it here and fall back to
# plain iteration when it is not installed, so the run does not depend on it.
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, *args, **kwargs):
        """Pass-through stand-in used when tqdm is unavailable."""
        return iterable

ans = []
begin, end, batch_size = 0, len(data), 4
cnt = 0
# The loop previously referenced an undefined `max_batch_size`; the stride is
# the `batch_size` declared above.
for batch_idx in tqdm(range(begin, end, batch_size)):
    up = min(batch_idx + batch_size, end)
    # NOTE(review): `batch` is sliced but the prompt below is always the fixed
    # `messages`; presumably the records are meant to be turned into prompts
    # here -- confirm against the schema of merged_data.json.
    batch = data[batch_idx:up]
    print(f"batch {batch_idx} to {up}")
    outputs = pipeline(messages, max_new_tokens=2048)[0]["generated_text"]
    ans.append(outputs)
    cnt += 1
    if cnt % 10 == 0:
        print(f"batch {cnt} done")
        # Periodic checkpoint. write_json takes (file_path, data); the
        # original call passed the two arguments in the opposite order.
        write_json("ans.json", ans)

# Final write so results from the last (fewer than 10) batches are not lost.
write_json("ans.json", ans)