Spaces:
Runtime error
Runtime error
Commit
ยท
df82baf
1
Parent(s):
4983843
add temperature to response generation
Browse files
src/synthetic_dataset_generator/pipelines/sft.py
CHANGED
@@ -207,7 +207,7 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
|
|
207 |
return magpie_generator
|
208 |
|
209 |
|
210 |
-
def get_response_generator(system_prompt, num_turns, is_sample):
|
211 |
if num_turns == 1:
|
212 |
response_generator = TextGeneration(
|
213 |
llm=InferenceEndpointsLLM(
|
@@ -216,7 +216,7 @@ def get_response_generator(system_prompt, num_turns, is_sample):
|
|
216 |
base_url=BASE_URL,
|
217 |
api_key=_get_next_api_key(),
|
218 |
generation_kwargs={
|
219 |
-
"temperature":
|
220 |
"max_new_tokens": 256 if is_sample else 1024,
|
221 |
},
|
222 |
),
|
@@ -232,7 +232,7 @@ def get_response_generator(system_prompt, num_turns, is_sample):
|
|
232 |
base_url=BASE_URL,
|
233 |
api_key=_get_next_api_key(),
|
234 |
generation_kwargs={
|
235 |
-
"temperature":
|
236 |
"max_new_tokens": 2048,
|
237 |
},
|
238 |
),
|
|
|
207 |
return magpie_generator
|
208 |
|
209 |
|
210 |
+
def get_response_generator(system_prompt, num_turns, temperature, is_sample):
|
211 |
if num_turns == 1:
|
212 |
response_generator = TextGeneration(
|
213 |
llm=InferenceEndpointsLLM(
|
|
|
216 |
base_url=BASE_URL,
|
217 |
api_key=_get_next_api_key(),
|
218 |
generation_kwargs={
|
219 |
+
"temperature": temperature,
|
220 |
"max_new_tokens": 256 if is_sample else 1024,
|
221 |
},
|
222 |
),
|
|
|
232 |
base_url=BASE_URL,
|
233 |
api_key=_get_next_api_key(),
|
234 |
generation_kwargs={
|
235 |
+
"temperature": temperature,
|
236 |
"max_new_tokens": 2048,
|
237 |
},
|
238 |
),
|