Wendy-Fly committed
Commit 62cd74f · verified · 1 Parent(s): b327cc4

Upload generate_prompt.py with huggingface_hub
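The commit message notes the file was pushed with huggingface_hub. A minimal sketch of such an upload, assuming a logged-in client and a placeholder repo id (the actual repo is not shown on this page):

    from huggingface_hub import HfApi

    api = HfApi()  # uses the token from `huggingface-cli login` or HF_TOKEN
    api.upload_file(
        path_or_fileobj="generate_prompt.py",   # local file to push
        path_in_repo="generate_prompt.py",      # destination path inside the repo
        repo_id="<username>/<repo>",            # placeholder, not taken from the commit
        commit_message="Upload generate_prompt.py with huggingface_hub",
    )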

Files changed (1)
  1. generate_prompt.py +25 -26
generate_prompt.py CHANGED
@@ -88,32 +88,31 @@ for batch_idx in tqdm(range(begin, end, batch_size)):
  data_list.append(messages)
  save_list.append(save_)
  #print(len(data_list))
- text = processor.apply_chat_template(data_list, tokenize=False, add_generation_prompt=True)
- #print(len(text))
- image_inputs, video_inputs = process_vision_info(data_list)
- inputs = processor(
-     text=[text],
-     images=image_inputs,
-     videos=video_inputs,
-     padding=True,
-     return_tensors="pt",
- )
- inputs = inputs.to(model.device)
-
- # Inference: Generation of the output
- generated_ids = model.generate(**inputs, max_new_tokens=128)
- #print(generated_ids.shape)
- generated_ids_trimmed = [
-     out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
- ]
- output_text = processor.batch_decode(
-     generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
- )
- print(output_text)
- for idx,x in enumerate(output_text):
-     idx_real = batch_idx * batch_size + idx
-     save_list[idx][0]['result'] = x
-     save_data.append(save_list[idx])
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ #print(len(text))
+ image_inputs, video_inputs = process_vision_info(messages)
+ inputs = processor(
+     text=[text],
+     images=image_inputs,
+     videos=video_inputs,
+     padding=True,
+     return_tensors="pt",
+ )
+ inputs = inputs.to(model.device)
+
+ # Inference: Generation of the output
+ generated_ids = model.generate(**inputs, max_new_tokens=128)
+ #print(generated_ids.shape)
+ generated_ids_trimmed = [
+     out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+ ]
+ output_text = processor.batch_decode(
+     generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+ #print(output_text)
+ save_[0]['result'] = x
+ save_data.append(save_)
+
  if batch_idx % 4 ==0:
      write_json(json_path,save_data)
      print(len(save_data))
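After this change the result is stored per sample, but `x` comes from the removed `for idx,x in enumerate(output_text)` loop and is no longer defined at that point. A minimal sketch of how the assignment might read, assuming the intent is to keep the single decoded string for the current sample (an assumption, not the committed code):

    # batch_decode returns a list of strings; with one prompt per call,
    # the first element is the answer for the current sample (assumed intent).
    answer = output_text[0] if output_text else ""
    save_[0]['result'] = answer
    save_data.append(save_)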