Wendy-Fly commited on
Commit
cced866
·
verified ·
1 Parent(s): 5a10cc1

Upload generate_prompt.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. generate_prompt.py +2 -1
generate_prompt.py CHANGED
@@ -99,9 +99,10 @@ for batch_idx in tqdm(range(begin, end, batch_size)):
99
  return_tensors="pt",
100
  )
101
  inputs = inputs.to(model.device)
102
- print(inputs.shape)
103
  # Inference: Generation of the output
104
  generated_ids = model.generate(**inputs, max_new_tokens=128)
 
105
  generated_ids_trimmed = [
106
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
107
  ]
 
99
  return_tensors="pt",
100
  )
101
  inputs = inputs.to(model.device)
102
+
103
  # Inference: Generation of the output
104
  generated_ids = model.generate(**inputs, max_new_tokens=128)
105
+ print(generated_ids.shape)
106
  generated_ids_trimmed = [
107
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
108
  ]