Merge pull request #28 from argilla-io/bug/fix-bugs
Files changed:
- src/synthetic_dataset_generator/apps/base.py +2 -2
- src/synthetic_dataset_generator/apps/chat.py +81 -76
- src/synthetic_dataset_generator/apps/eval.py +1 -0
- src/synthetic_dataset_generator/apps/rag.py +6 -5
- src/synthetic_dataset_generator/apps/textcat.py +16 -10
- src/synthetic_dataset_generator/pipelines/chat.py +6 -4
- src/synthetic_dataset_generator/pipelines/eval.py +13 -3
- src/synthetic_dataset_generator/pipelines/rag.py +3 -4
- src/synthetic_dataset_generator/pipelines/textcat.py +4 -2

src/synthetic_dataset_generator/apps/base.py
CHANGED

@@ -64,7 +64,7 @@ def push_pipeline_code_to_hub(
     progress(1.0, desc="Pipeline code uploaded")


-def validate_push_to_hub(org_name, repo_name):
+def validate_push_to_hub(org_name: str, repo_name: str):
     repo_id = (
         f"{org_name}/{repo_name}"
         if repo_name is not None and org_name is not None
@@ -93,7 +93,7 @@ def combine_datasets(
     return dataset


-def show_success_message(org_name, repo_name) -> gr.Markdown:
+def show_success_message(org_name: str, repo_name: str) -> gr.Markdown:
     client = get_argilla_client()
     if client is None:
         return gr.Markdown(

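Note: the base.py hunks only add type hints. For reference, a minimal sketch of the validation pattern the first hunk touches; the else-branch and the error message below are illustrative assumptions, not the repository's exact code:

import gradio as gr

def validate_push_to_hub(org_name: str, repo_name: str):
    # The first lines mirror the diff context; the fallback is assumed.
    repo_id = (
        f"{org_name}/{repo_name}"
        if repo_name is not None and org_name is not None
        else None
    )
    if repo_id is None:
        # Hypothetical message; the real one may differ.
        raise gr.Error("Please provide both an org name and a repo name.")
    return repo_id
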
src/synthetic_dataset_generator/apps/chat.py
CHANGED

@@ -60,7 +60,7 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
     return dataframe


-def generate_system_prompt(dataset_description, progress=gr.Progress()):
+def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
     progress(0.1, desc="Initializing")
     generate_description = get_prompt_generator()
     progress(0.5, desc="Generating")
@@ -77,7 +77,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
     return result


-def generate_sample_dataset(system_prompt, num_turns, progress=gr.Progress()):
+def generate_sample_dataset(system_prompt: str, num_turns: int, progress=gr.Progress()):
    progress(0.1, desc="Generating sample dataset")
     dataframe = generate_dataset(
         system_prompt=system_prompt,
@@ -109,7 +109,7 @@ def generate_dataset(
     num_rows = test_max_num_rows(num_rows)
     progress(0.0, desc="(1/2) Generating instructions")
     magpie_generator = get_magpie_generator(
-        system_prompt, num_turns, temperature, is_sample
+        num_turns, temperature, is_sample
     )
     response_generator = get_response_generator(
         system_prompt, num_turns, temperature, is_sample
@@ -267,7 +267,12 @@ def push_dataset(
         temperature=temperature,
     )
     push_dataset_to_hub(
-        dataframe,
+        dataframe=dataframe,
+        org_name=org_name,
+        repo_name=repo_name,
+        oauth_token=oauth_token,
+        private=private,
+        pipeline_code=pipeline_code,
     )
     try:
         progress(0.1, desc="Setting up user and workspace")
@@ -524,77 +529,77 @@ with gr.Blocks() as app:
         label="Distilabel Pipeline Code",
     )

-    [74 removed lines: the previous event-wiring block; its content is not preserved in this rendered view]
+    load_btn.click(
+        fn=generate_system_prompt,
+        inputs=[dataset_description],
+        outputs=[system_prompt],
+        show_progress=True,
+    ).then(
+        fn=generate_sample_dataset,
+        inputs=[system_prompt, num_turns],
+        outputs=[dataframe],
+        show_progress=True,
+    )
+
+    btn_apply_to_sample_dataset.click(
+        fn=generate_sample_dataset,
+        inputs=[system_prompt, num_turns],
+        outputs=[dataframe],
+        show_progress=True,
+    )
+
+    btn_push_to_hub.click(
+        fn=validate_argilla_user_workspace_dataset,
+        inputs=[repo_name],
+        outputs=[success_message],
+        show_progress=True,
+    ).then(
+        fn=validate_push_to_hub,
+        inputs=[org_name, repo_name],
+        outputs=[success_message],
+        show_progress=True,
+    ).success(
+        fn=hide_success_message,
+        outputs=[success_message],
+        show_progress=True,
+    ).success(
+        fn=hide_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+        show_progress=True,
+    ).success(
+        fn=push_dataset,
+        inputs=[
+            org_name,
+            repo_name,
+            system_prompt,
+            num_turns,
+            num_rows,
+            private,
+            temperature,
+            pipeline_code,
+        ],
+        outputs=[success_message],
+        show_progress=True,
+    ).success(
+        fn=show_success_message,
+        inputs=[org_name, repo_name],
+        outputs=[success_message],
+    ).success(
+        fn=generate_pipeline_code,
+        inputs=[system_prompt, num_turns, num_rows],
+        outputs=[pipeline_code],
+    ).success(
+        fn=show_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    )
+    gr.on(
+        triggers=[clear_btn_part.click, clear_btn_full.click],
+        fn=lambda _: ("", "", 1, _get_dataframe()),
+        inputs=[dataframe],
+        outputs=[dataset_description, system_prompt, num_turns, dataframe],
+    )
+    app.load(fn=get_org_dropdown, outputs=[org_name])
+    app.load(fn=get_random_repo_name, outputs=[repo_name])
+    app.load(fn=swap_visibility, outputs=main_ui)

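Note: most of this file's churn is the rewritten event wiring above. As a reference for the pattern, here is a self-contained sketch of Gradio's event chaining with hypothetical component names: .then(...) always queues the next step, while .success(...) runs only if the previous step finished without raising, which is what lets a failed validate_push_to_hub short-circuit the push.

import gradio as gr

def validate(text: str) -> str:
    if not text:
        # Raising gr.Error aborts any following .success(...) steps.
        raise gr.Error("Input is required")
    return text

def generate(text: str) -> str:
    return f"dataset for: {text}"

with gr.Blocks() as demo:
    description = gr.Textbox(label="Dataset description")
    output = gr.Textbox(label="Result")
    btn = gr.Button("Generate")
    # .success only fires when validate completed without raising,
    # mirroring the validate -> push -> show-message chain above.
    btn.click(fn=validate, inputs=[description], outputs=[description]).success(
        fn=generate, inputs=[description], outputs=[output]
    )

demo.launch()
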
src/synthetic_dataset_generator/apps/eval.py
CHANGED

@@ -889,6 +889,7 @@ with gr.Blocks() as app:
         outputs=[
             instruction_instruction_response,
             response_instruction_response,
+            dataframe
         ],
     )
 
src/synthetic_dataset_generator/apps/rag.py
CHANGED

@@ -76,7 +76,7 @@ def _load_dataset_from_hub(
     progress=gr.Progress(track_tqdm=True),
 ):
     if not repo_id:
-        raise gr.Error("Hub repo
+        raise gr.Error("Please provide a Hub repo ID")
     subsets = get_dataset_config_names(repo_id, token=token)
     splits = get_dataset_split_names(repo_id, subsets[0], token=token)
     ds = load_dataset(repo_id, subsets[0], split=splits[0], token=token, streaming=True)
@@ -101,7 +101,10 @@ def _load_dataset_from_hub(
     )


-def _preprocess_input_data(file_paths, num_rows, progress=gr.Progress(track_tqdm=True)):
+def _preprocess_input_data(file_paths: list[str], num_rows: int, progress=gr.Progress(track_tqdm=True)):
+    if not file_paths:
+        raise gr.Error("Please provide an input file")
+
     data = {}
     total_chunks = 0

@@ -131,7 +134,7 @@ def _preprocess_input_data(file_paths, num_rows, progress=gr.Progress(track_tqdm
     )


-def generate_system_prompt(dataset_description, progress=gr.Progress()):
+def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
     progress(0.1, desc="Initializing")
     generate_description = get_prompt_generator()
     progress(0.5, desc="Generating")
@@ -753,7 +756,6 @@ with gr.Blocks() as app:
     ) as pipeline_code_ui:
         code = generate_pipeline_code(
             repo_id=search_in.value,
-            file_paths=file_in.value,
             input_type=input_type.value,
             system_prompt=system_prompt.value,
             document_column=document_column.value,
@@ -891,7 +893,6 @@ with gr.Blocks() as app:
         fn=generate_pipeline_code,
         inputs=[
             search_in,
-            file_in,
             input_type,
             system_prompt,
             document_column,

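Note: the rag.py fixes share one theme: validate user input up front and surface a readable gr.Error instead of failing later inside the loaders. A minimal sketch of that guard style, with hypothetical component and function names:

import gradio as gr

def process(file_paths: list[str] | None) -> str:
    # Without this guard an empty upload reaches the chunking code and
    # fails with an opaque traceback; gr.Error shows a clear message.
    if not file_paths:
        raise gr.Error("Please provide an input file")
    return f"processing {len(file_paths)} file(s)"

with gr.Blocks() as demo:
    files = gr.File(file_count="multiple", type="filepath")
    status = gr.Textbox(label="Status")
    files.upload(fn=process, inputs=[files], outputs=[status])

demo.launch()
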
src/synthetic_dataset_generator/apps/textcat.py
CHANGED

@@ -49,7 +49,7 @@ def _get_dataframe():
     )


-def generate_system_prompt(dataset_description, progress=gr.Progress()):
+def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
     progress(0.0, desc="Starting")
     progress(0.3, desc="Initializing")
     generate_description = get_prompt_generator()
@@ -71,7 +71,12 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):


 def generate_sample_dataset(
-    system_prompt
+    system_prompt: str,
+    difficulty: str,
+    clarity: str,
+    labels: List[str],
+    multi_label: bool,
+    progress=gr.Progress(),
 ):
     dataframe = generate_dataset(
         system_prompt=system_prompt,
@@ -294,14 +299,14 @@ def push_dataset(
         temperature=temperature,
     )
     push_dataset_to_hub(
-        dataframe,
-        org_name,
-        repo_name,
-        multi_label,
-        labels,
-        oauth_token,
-        private,
-        pipeline_code,
+        dataframe=dataframe,
+        org_name=org_name,
+        repo_name=repo_name,
+        multi_label=multi_label,
+        labels=labels,
+        oauth_token=oauth_token,
+        private=private,
+        pipeline_code=pipeline_code,
     )

     dataframe = dataframe[
@@ -657,6 +662,7 @@ with gr.Blocks() as app:
             "",
             "",
             [],
+            "",
             _get_dataframe(),
         ),
         inputs=[dataframe],

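Note: the push_dataset_to_hub call above switches from positional to keyword arguments. With eight parameters of similar types, positional calls misbind silently whenever the signature gains, loses, or reorders a parameter; binding by name removes that failure mode. A runnable toy illustration with a stand-in function:

def push_dataset_to_hub(dataframe, org_name, repo_name, multi_label=False,
                        labels=None, oauth_token=None, private=False,
                        pipeline_code=""):
    # Hypothetical stand-in for the real helper; prints what it received.
    print(org_name, repo_name, multi_label, labels, private)

rows = [{"text": "great product", "label": "positive"}]

# Positional: multi_label and labels end up swapped, with no error raised.
push_dataset_to_hub(rows, "argilla", "my-dataset", ["positive"], True)

# Keyword (the fixed style): arguments bind by name, order is irrelevant.
push_dataset_to_hub(
    dataframe=rows,
    org_name="argilla",
    repo_name="my-dataset",
    multi_label=True,
    labels=["positive", "negative"],
)
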
src/synthetic_dataset_generator/pipelines/chat.py
CHANGED

@@ -140,7 +140,7 @@ else:
 ]


-def _get_output_mappings(num_turns):
+def _get_output_mappings(num_turns: int):
     if num_turns == 1:
         return {"instruction": "prompt", "response": "completion"}
     else:
@@ -162,7 +162,7 @@ def get_prompt_generator():
     return prompt_generator


-def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
+def get_magpie_generator(num_turns: int, temperature: float, is_sample: bool):
     input_mappings = _get_output_mappings(num_turns)
     output_mappings = input_mappings.copy()
     if num_turns == 1:
@@ -203,7 +203,9 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
     return magpie_generator


-def get_response_generator(system_prompt, num_turns, temperature, is_sample):
+def get_response_generator(
+    system_prompt: str, num_turns: int, temperature: float, is_sample: bool
+):
     if num_turns == 1:
         generation_kwargs = {
             "temperature": temperature,
@@ -229,7 +231,7 @@ def get_response_generator(system_prompt, num_turns, temperature, is_sample):
     return response_generator


-def generate_pipeline_code(system_prompt, num_turns, num_rows):
+def generate_pipeline_code(system_prompt: str, num_turns: int, num_rows: int):
     input_mappings = _get_output_mappings(num_turns)

     code = f"""

src/synthetic_dataset_generator/pipelines/eval.py
CHANGED

@@ -1,3 +1,5 @@
+from typing import List
+
 from datasets import get_dataset_config_names, get_dataset_split_names
 from distilabel.models import InferenceEndpointsLLM
 from distilabel.steps.tasks import (
@@ -10,7 +12,7 @@ from synthetic_dataset_generator.pipelines.base import _get_next_api_key
 from synthetic_dataset_generator.utils import extract_column_names


-def get_ultrafeedback_evaluator(aspect, is_sample):
+def get_ultrafeedback_evaluator(aspect: str, is_sample: bool):
     ultrafeedback_evaluator = UltraFeedback(
         llm=InferenceEndpointsLLM(
             model_id=MODEL,
@@ -27,7 +29,9 @@ def get_ultrafeedback_evaluator(aspect, is_sample):
     return ultrafeedback_evaluator


-def get_custom_evaluator(prompt_template, structured_output, columns, is_sample):
+def get_custom_evaluator(
+    prompt_template: str, structured_output: dict, columns: List[str], is_sample: bool
+):
     custom_evaluator = TextGeneration(
         llm=InferenceEndpointsLLM(
             model_id=MODEL,
@@ -47,7 +51,13 @@ def get_custom_evaluator(prompt_template, structured_output, columns, is_sample)


 def generate_ultrafeedback_pipeline_code(
-    repo_id
+    repo_id: str,
+    subset: str,
+    split: str,
+    aspects: List[str],
+    instruction_column: str,
+    response_columns: str,
+    num_rows: int,
 ):
     if len(aspects) == 1:
         code = f"""

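Note: eval.py now imports typing.List for the new annotations. On Python 3.9+ the built-in list[str] form would also work; typing.List keeps compatibility with older interpreters. A toy sketch (the function name is hypothetical) showing that these annotations document intent without runtime enforcement:

from typing import List

def describe_columns(columns: List[str], is_sample: bool) -> str:
    # Annotations guide editors and type checkers; Python itself does
    # not validate them when the function is called.
    return f"{len(columns)} columns, sample={is_sample}"

print(describe_columns(["instruction", "response_1"], True))
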
src/synthetic_dataset_generator/pipelines/rag.py
CHANGED

@@ -87,7 +87,7 @@ def get_prompt_generator():
     return text_generator


-def get_chunks_generator(temperature, is_sample):
+def get_chunks_generator(temperature: float, is_sample: bool):
     generation_kwargs = {
         "temperature": temperature,
         "max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
@@ -104,7 +104,7 @@ def get_chunks_generator(temperature, is_sample):
     return text_generator


-def get_sentence_pair_generator(action, triplet, temperature, is_sample):
+def get_sentence_pair_generator(action: str, triplet: bool, temperature: float, is_sample: bool):
     generation_kwargs = {
         "temperature": temperature,
         "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
@@ -119,7 +119,7 @@ def get_sentence_pair_generator(action, triplet, temperature, is_sample):
     return sentence_pair_generator


-def get_response_generator(temperature, is_sample):
+def get_response_generator(temperature: float, is_sample: bool):
     generation_kwargs = {
         "temperature": temperature,
         "max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
@@ -138,7 +138,6 @@ def get_response_generator(temperature, is_sample):

 def generate_pipeline_code(
     repo_id: str,
-    file_paths: List[str],
     input_type: str,
     system_prompt: str,
     document_column: str,

src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED

@@ -85,7 +85,9 @@ def get_prompt_generator():
     return prompt_generator


-def get_textcat_generator(difficulty, clarity, temperature, is_sample):
+def get_textcat_generator(
+    difficulty: str, clarity: str, temperature: float, is_sample: bool
+):
     generation_kwargs = {
         "temperature": temperature,
         "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
@@ -102,7 +104,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
     return textcat_generator


-def get_labeller_generator(system_prompt, labels, multi_label):
+def get_labeller_generator(system_prompt: str, labels: List[str], multi_label: bool):
     generation_kwargs = {
         "temperature": 0.01,
         "max_new_tokens": MAX_NUM_TOKENS,