davidberenstein1957 HF staff commited on
Commit
79801ad
·
unverified ·
2 Parent(s): d5933e1 df82baf

Merge pull request #7 from argilla-io/feat/improve-sft

Browse files
src/synthetic_dataset_generator/pipelines/sft.py CHANGED
@@ -207,7 +207,7 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
207
  return magpie_generator
208
 
209
 
210
- def get_response_generator(system_prompt, num_turns, is_sample):
211
  if num_turns == 1:
212
  response_generator = TextGeneration(
213
  llm=InferenceEndpointsLLM(
@@ -216,7 +216,7 @@ def get_response_generator(system_prompt, num_turns, is_sample):
216
  base_url=BASE_URL,
217
  api_key=_get_next_api_key(),
218
  generation_kwargs={
219
- "temperature": 0.8,
220
  "max_new_tokens": 256 if is_sample else 1024,
221
  },
222
  ),
@@ -232,7 +232,7 @@ def get_response_generator(system_prompt, num_turns, is_sample):
232
  base_url=BASE_URL,
233
  api_key=_get_next_api_key(),
234
  generation_kwargs={
235
- "temperature": 0.8,
236
  "max_new_tokens": 2048,
237
  },
238
  ),
 
207
  return magpie_generator
208
 
209
 
210
+ def get_response_generator(system_prompt, num_turns, temperature, is_sample):
211
  if num_turns == 1:
212
  response_generator = TextGeneration(
213
  llm=InferenceEndpointsLLM(
 
216
  base_url=BASE_URL,
217
  api_key=_get_next_api_key(),
218
  generation_kwargs={
219
+ "temperature": temperature,
220
  "max_new_tokens": 256 if is_sample else 1024,
221
  },
222
  ),
 
232
  base_url=BASE_URL,
233
  api_key=_get_next_api_key(),
234
  generation_kwargs={
235
+ "temperature": temperature,
236
  "max_new_tokens": 2048,
237
  },
238
  ),