davidberenstein1957 HF staff commited on
Commit
df82baf
ยท
1 Parent(s): 4983843

add temperature to response generation

Browse files
src/synthetic_dataset_generator/pipelines/sft.py CHANGED
@@ -207,7 +207,7 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
207
  return magpie_generator
208
 
209
 
210
- def get_response_generator(system_prompt, num_turns, is_sample):
211
  if num_turns == 1:
212
  response_generator = TextGeneration(
213
  llm=InferenceEndpointsLLM(
@@ -216,7 +216,7 @@ def get_response_generator(system_prompt, num_turns, is_sample):
216
  base_url=BASE_URL,
217
  api_key=_get_next_api_key(),
218
  generation_kwargs={
219
- "temperature": 0.8,
220
  "max_new_tokens": 256 if is_sample else 1024,
221
  },
222
  ),
@@ -232,7 +232,7 @@ def get_response_generator(system_prompt, num_turns, is_sample):
232
  base_url=BASE_URL,
233
  api_key=_get_next_api_key(),
234
  generation_kwargs={
235
- "temperature": 0.8,
236
  "max_new_tokens": 2048,
237
  },
238
  ),
 
207
  return magpie_generator
208
 
209
 
210
+ def get_response_generator(system_prompt, num_turns, temperature, is_sample):
211
  if num_turns == 1:
212
  response_generator = TextGeneration(
213
  llm=InferenceEndpointsLLM(
 
216
  base_url=BASE_URL,
217
  api_key=_get_next_api_key(),
218
  generation_kwargs={
219
+ "temperature": temperature,
220
  "max_new_tokens": 256 if is_sample else 1024,
221
  },
222
  ),
 
232
  base_url=BASE_URL,
233
  api_key=_get_next_api_key(),
234
  generation_kwargs={
235
+ "temperature": temperature,
236
  "max_new_tokens": 2048,
237
  },
238
  ),