Spaces:
Runtime error
Runtime error
sebaxakerhtc
committed on
Added local saving to CSV and JSON (#38)
Browse files
* Local save
- Added save local function to chat tab (CSV, JSON)
- Rebuild UI with new feature
- CSS edit for gr.File (perfectionism)
* Local save
* Mistake
* Update chat.py
* Local save RAG and Textcat
* Rebuild UI
* Show save_local only if save_local_dir is provided
- src/synthetic_dataset_generator/app.py +5 -4
- src/synthetic_dataset_generator/apps/base.py +5 -1
- src/synthetic_dataset_generator/apps/chat.py +73 -2
- src/synthetic_dataset_generator/apps/rag.py +79 -2
- src/synthetic_dataset_generator/apps/textcat.py +62 -2
- src/synthetic_dataset_generator/constants.py +3 -0
src/synthetic_dataset_generator/app.py
CHANGED
@@ -12,12 +12,13 @@ css = """
|
|
12 |
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
|
13 |
button[role="tab"][aria-selected="true"] { border: 0; background: var(--button-primary-background-fill); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
|
14 |
button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill); background: var(var(--button-primary-background-fill-hover))}
|
15 |
-
.tabitem {
|
16 |
.gallery-item {background: var(--background-fill-secondary); text-align: left}
|
17 |
-
.table-wrap .tbody td {
|
18 |
-
#system_prompt_examples {
|
19 |
.container {padding-inline: 0 !important}
|
20 |
-
#sign_in_button {
|
|
|
21 |
"""
|
22 |
|
23 |
image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
|
|
|
12 |
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
|
13 |
button[role="tab"][aria-selected="true"] { border: 0; background: var(--button-primary-background-fill); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
|
14 |
button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill); background: var(var(--button-primary-background-fill-hover))}
|
15 |
+
.tabitem {border: 0; padding-inline: 0}
|
16 |
.gallery-item {background: var(--background-fill-secondary); text-align: left}
|
17 |
+
.table-wrap .tbody td {vertical-align: top}
|
18 |
+
#system_prompt_examples {color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
|
19 |
.container {padding-inline: 0 !important}
|
20 |
+
#sign_in_button {flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto;}
|
21 |
+
.datasets {height: 70px;}
|
22 |
"""
|
23 |
|
24 |
image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
|
src/synthetic_dataset_generator/apps/base.py
CHANGED
@@ -12,9 +12,13 @@ from huggingface_hub import HfApi, upload_file, repo_exists
|
|
12 |
from unstructured.chunking.title import chunk_by_title
|
13 |
from unstructured.partition.auto import partition
|
14 |
|
15 |
-
from synthetic_dataset_generator.constants import MAX_NUM_ROWS
|
16 |
from synthetic_dataset_generator.utils import get_argilla_client
|
17 |
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def validate_argilla_user_workspace_dataset(
|
20 |
dataset_name: str,
|
|
|
12 |
from unstructured.chunking.title import chunk_by_title
|
13 |
from unstructured.partition.auto import partition
|
14 |
|
15 |
+
from synthetic_dataset_generator.constants import MAX_NUM_ROWS, SAVE_LOCAL_DIR
|
16 |
from synthetic_dataset_generator.utils import get_argilla_client
|
17 |
|
18 |
+
# Local export is optional: the output directory is only created at import
# time when SAVE_LOCAL_DIR has been configured.
if SAVE_LOCAL_DIR is not None:
    import os

    os.makedirs(name=SAVE_LOCAL_DIR, exist_ok=True)
|
21 |
+
|
22 |
|
23 |
def validate_argilla_user_workspace_dataset(
|
24 |
dataset_name: str,
|
src/synthetic_dataset_generator/apps/chat.py
CHANGED
@@ -2,6 +2,7 @@ import ast
|
|
2 |
import json
|
3 |
import random
|
4 |
import uuid
|
|
|
5 |
from typing import Dict, List, Union
|
6 |
|
7 |
import argilla as rg
|
@@ -30,6 +31,7 @@ from synthetic_dataset_generator.constants import (
|
|
30 |
MODEL,
|
31 |
MODEL_COMPLETION,
|
32 |
SFT_AVAILABLE,
|
|
|
33 |
)
|
34 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
35 |
from synthetic_dataset_generator.pipelines.chat import (
|
@@ -264,7 +266,6 @@ def generate_dataset_from_prompt(
|
|
264 |
progress(1.0, desc="Dataset generation completed")
|
265 |
return dataframe
|
266 |
|
267 |
-
|
268 |
def generate_dataset_from_seed(
|
269 |
dataframe: pd.DataFrame,
|
270 |
document_column: str,
|
@@ -506,7 +507,7 @@ def push_dataset(
|
|
506 |
num_turns=num_turns,
|
507 |
num_rows=num_rows,
|
508 |
temperature=temperature,
|
509 |
-
temperature_completion=temperature_completion
|
510 |
)
|
511 |
push_dataset_to_hub(
|
512 |
dataframe=dataframe,
|
@@ -637,6 +638,45 @@ def push_dataset(
|
|
637 |
return ""
|
638 |
|
639 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
640 |
def show_system_prompt_visibility():
|
641 |
return {system_prompt: gr.Textbox(visible=True)}
|
642 |
|
@@ -670,6 +710,13 @@ def hide_pipeline_code_visibility():
|
|
670 |
def show_temperature_completion():
|
671 |
if MODEL != MODEL_COMPLETION:
|
672 |
return {temperature_completion: gr.Slider(value=0.9, visible=True)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
673 |
|
674 |
|
675 |
######################
|
@@ -852,6 +899,11 @@ with gr.Blocks() as app:
|
|
852 |
btn_push_to_hub = gr.Button(
|
853 |
"Push to Hub", variant="primary", scale=2
|
854 |
)
|
|
|
|
|
|
|
|
|
|
|
855 |
with gr.Column(scale=3):
|
856 |
success_message = gr.Markdown(
|
857 |
visible=True,
|
@@ -998,6 +1050,23 @@ with gr.Blocks() as app:
|
|
998 |
inputs=[],
|
999 |
outputs=[pipeline_code_ui],
|
1000 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1001 |
|
1002 |
clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
|
1003 |
clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
|
@@ -1011,3 +1080,5 @@ with gr.Blocks() as app:
|
|
1011 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
1012 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
1013 |
app.load(fn=show_temperature_completion, outputs=[temperature_completion])
|
|
|
|
|
|
2 |
import json
|
3 |
import random
|
4 |
import uuid
|
5 |
+
import os
|
6 |
from typing import Dict, List, Union
|
7 |
|
8 |
import argilla as rg
|
|
|
31 |
MODEL,
|
32 |
MODEL_COMPLETION,
|
33 |
SFT_AVAILABLE,
|
34 |
+
SAVE_LOCAL_DIR,
|
35 |
)
|
36 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
37 |
from synthetic_dataset_generator.pipelines.chat import (
|
|
|
266 |
progress(1.0, desc="Dataset generation completed")
|
267 |
return dataframe
|
268 |
|
|
|
269 |
def generate_dataset_from_seed(
|
270 |
dataframe: pd.DataFrame,
|
271 |
document_column: str,
|
|
|
507 |
num_turns=num_turns,
|
508 |
num_rows=num_rows,
|
509 |
temperature=temperature,
|
510 |
+
temperature_completion=temperature_completion,
|
511 |
)
|
512 |
push_dataset_to_hub(
|
513 |
dataframe=dataframe,
|
|
|
638 |
return ""
|
639 |
|
640 |
|
641 |
+
def save_local(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int,
    num_rows: int,
    temperature: float,
    repo_name: str,
    temperature_completion: Union[float, None] = None,
) -> tuple[str, str]:
    """Generate a chat dataset and write it to ``SAVE_LOCAL_DIR`` as CSV and JSON.

    Args:
        repo_id: Forwarded to ``load_dataset_file`` when the input is not a prompt.
        file_paths: Forwarded to ``load_dataset_file`` when the input is not a prompt.
        input_type: ``"prompt-input"`` starts from ``_get_dataframe()``; any other
            value loads the seed data through ``load_dataset_file``.
        system_prompt: Forwarded to ``generate_dataset``.
        document_column: Forwarded to ``generate_dataset``.
        num_turns: Forwarded to ``generate_dataset``.
        num_rows: Forwarded to ``load_dataset_file`` and ``generate_dataset``.
        temperature: Forwarded to ``generate_dataset``.
        repo_name: Base file name (without extension) of the two output files.
        temperature_completion: Forwarded to ``generate_dataset``.

    Returns:
        The paths of the written CSV file and JSON file, in that order.
    """
    if input_type == "prompt-input":
        dataframe = _get_dataframe()
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
        )
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    local_dataset = Dataset.from_pandas(dataframe)
    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
    local_dataset.to_csv(output_csv, index=False)
    # No `index` keyword here: the JSON writer hands extra keyword arguments to
    # pandas, and pandas releases before 2.1 raise ValueError for index=False
    # with record-oriented output (which carries no index in the first place).
    local_dataset.to_json(output_json)
    return output_csv, output_json
|
678 |
+
|
679 |
+
|
680 |
def show_system_prompt_visibility():
|
681 |
return {system_prompt: gr.Textbox(visible=True)}
|
682 |
|
|
|
710 |
def show_temperature_completion():
|
711 |
if MODEL != MODEL_COMPLETION:
|
712 |
return {temperature_completion: gr.Slider(value=0.9, visible=True)}
|
713 |
+
|
714 |
+
def show_save_local():
    """Map the save button and both file components to visible replacements."""
    revealed = {}
    revealed[btn_save_local] = gr.Button(visible=True)
    revealed[csv_file] = gr.File(visible=True)
    revealed[json_file] = gr.File(visible=True)
    return revealed
|
720 |
|
721 |
|
722 |
######################
|
|
|
899 |
btn_push_to_hub = gr.Button(
|
900 |
"Push to Hub", variant="primary", scale=2
|
901 |
)
|
902 |
+
btn_save_local = gr.Button(
|
903 |
+
"Save locally", variant="primary", scale=2, visible=False
|
904 |
+
)
|
905 |
+
csv_file = gr.File(label="CSV", elem_classes="datasets", visible=False)
|
906 |
+
json_file = gr.File(label="JSON", elem_classes="datasets", visible=False)
|
907 |
with gr.Column(scale=3):
|
908 |
success_message = gr.Markdown(
|
909 |
visible=True,
|
|
|
1050 |
inputs=[],
|
1051 |
outputs=[pipeline_code_ui],
|
1052 |
)
|
1053 |
+
|
1054 |
+
btn_save_local.click(
|
1055 |
+
save_local,
|
1056 |
+
inputs=[
|
1057 |
+
search_in,
|
1058 |
+
file_in,
|
1059 |
+
input_type,
|
1060 |
+
system_prompt,
|
1061 |
+
document_column,
|
1062 |
+
num_turns,
|
1063 |
+
num_rows,
|
1064 |
+
temperature,
|
1065 |
+
repo_name,
|
1066 |
+
temperature_completion,
|
1067 |
+
],
|
1068 |
+
outputs=[csv_file, json_file]
|
1069 |
+
)
|
1070 |
|
1071 |
clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
|
1072 |
clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
|
|
|
1080 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
1081 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
1082 |
app.load(fn=show_temperature_completion, outputs=[temperature_completion])
|
1083 |
+
if SAVE_LOCAL_DIR is not None:
|
1084 |
+
app.load(fn=show_save_local, outputs=[btn_save_local, csv_file, json_file])
|
src/synthetic_dataset_generator/apps/rag.py
CHANGED
@@ -24,7 +24,7 @@ from synthetic_dataset_generator.apps.base import (
|
|
24 |
validate_argilla_user_workspace_dataset,
|
25 |
validate_push_to_hub,
|
26 |
)
|
27 |
-
from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, MODEL, MODEL_COMPLETION
|
28 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
29 |
from synthetic_dataset_generator.pipelines.embeddings import (
|
30 |
get_embeddings,
|
@@ -486,6 +486,49 @@ def push_dataset(
|
|
486 |
return ""
|
487 |
|
488 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
489 |
def show_system_prompt_visibility():
|
490 |
return {system_prompt: gr.Textbox(visible=True)}
|
491 |
|
@@ -521,6 +564,14 @@ def show_temperature_completion():
|
|
521 |
return {temperature_completion: gr.Slider(value=0.9, visible=True)}
|
522 |
|
523 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
524 |
######################
|
525 |
# Gradio UI
|
526 |
######################
|
@@ -674,7 +725,14 @@ with gr.Blocks() as app:
|
|
674 |
interactive=True,
|
675 |
scale=1,
|
676 |
)
|
677 |
-
btn_push_to_hub = gr.Button(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
678 |
with gr.Column(scale=3):
|
679 |
success_message = gr.Markdown(
|
680 |
visible=True,
|
@@ -822,6 +880,23 @@ with gr.Blocks() as app:
|
|
822 |
outputs=[pipeline_code_ui],
|
823 |
)
|
824 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
825 |
clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
|
826 |
clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
|
827 |
clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
|
@@ -835,3 +910,5 @@ with gr.Blocks() as app:
|
|
835 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
836 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
837 |
app.load(fn=show_temperature_completion, outputs=[temperature_completion])
|
|
|
|
|
|
24 |
validate_argilla_user_workspace_dataset,
|
25 |
validate_push_to_hub,
|
26 |
)
|
27 |
+
from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, MODEL, MODEL_COMPLETION, SAVE_LOCAL_DIR
|
28 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
29 |
from synthetic_dataset_generator.pipelines.embeddings import (
|
30 |
get_embeddings,
|
|
|
486 |
return ""
|
487 |
|
488 |
|
489 |
+
def save_local(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: int,
    temperature: float,
    repo_name: str,
    temperature_completion: float,
) -> tuple[str, str]:
    """Generate a RAG dataset and write it to ``SAVE_LOCAL_DIR`` as CSV and JSON.

    Args:
        repo_id: Forwarded to ``load_dataset_file`` when the input is not a prompt.
        file_paths: Forwarded to ``load_dataset_file`` when the input is not a prompt.
        input_type: ``"prompt-input"`` starts from an empty frame with the columns
            ``context``, ``question`` and ``response``; any other value loads the
            seed data through ``load_dataset_file``.
        system_prompt: Forwarded to ``generate_dataset``.
        document_column: Forwarded to ``generate_dataset``.
        retrieval_reranking: Selected options; the presence of ``"Retrieval"`` and
            ``"Reranking"`` sets the corresponding ``generate_dataset`` flags.
        num_rows: Forwarded to ``load_dataset_file`` and ``generate_dataset``.
        temperature: Forwarded to ``generate_dataset``.
        repo_name: Base file name (without extension) of the two output files.
        temperature_completion: Forwarded to ``generate_dataset``.

    Returns:
        The paths of the written CSV file and JSON file, in that order.
    """
    retrieval = "Retrieval" in retrieval_reranking
    reranking = "Reranking" in retrieval_reranking

    if input_type == "prompt-input":
        dataframe = pd.DataFrame(columns=["context", "question", "response"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
        )
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        retrieval=retrieval,
        reranking=reranking,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    local_dataset = Dataset.from_pandas(dataframe)
    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
    local_dataset.to_csv(output_csv, index=False)
    # No `index` keyword here: the JSON writer hands extra keyword arguments to
    # pandas, and pandas releases before 2.1 raise ValueError for index=False
    # with record-oriented output (which carries no index in the first place).
    local_dataset.to_json(output_json)
    return output_csv, output_json
|
530 |
+
|
531 |
+
|
532 |
def show_system_prompt_visibility():
|
533 |
return {system_prompt: gr.Textbox(visible=True)}
|
534 |
|
|
|
564 |
return {temperature_completion: gr.Slider(value=0.9, visible=True)}
|
565 |
|
566 |
|
567 |
+
def show_save_local():
    """Map the save button and both file components to visible replacements."""
    revealed = {}
    revealed[btn_save_local] = gr.Button(visible=True)
    revealed[csv_file] = gr.File(visible=True)
    revealed[json_file] = gr.File(visible=True)
    return revealed
|
573 |
+
|
574 |
+
|
575 |
######################
|
576 |
# Gradio UI
|
577 |
######################
|
|
|
725 |
interactive=True,
|
726 |
scale=1,
|
727 |
)
|
728 |
+
btn_push_to_hub = gr.Button(
|
729 |
+
"Push to Hub", variant="primary", scale=2
|
730 |
+
)
|
731 |
+
btn_save_local = gr.Button(
|
732 |
+
"Save locally", variant="primary", scale=2, visible=False
|
733 |
+
)
|
734 |
+
csv_file = gr.File(label="CSV", elem_classes="datasets", visible=False)
|
735 |
+
json_file = gr.File(label="JSON", elem_classes="datasets", visible=False)
|
736 |
with gr.Column(scale=3):
|
737 |
success_message = gr.Markdown(
|
738 |
visible=True,
|
|
|
880 |
outputs=[pipeline_code_ui],
|
881 |
)
|
882 |
|
883 |
+
btn_save_local.click(
|
884 |
+
save_local,
|
885 |
+
inputs=[
|
886 |
+
search_in,
|
887 |
+
file_in,
|
888 |
+
input_type,
|
889 |
+
system_prompt,
|
890 |
+
document_column,
|
891 |
+
retrieval_reranking,
|
892 |
+
num_rows,
|
893 |
+
temperature,
|
894 |
+
repo_name,
|
895 |
+
temperature_completion,
|
896 |
+
],
|
897 |
+
outputs=[csv_file, json_file]
|
898 |
+
)
|
899 |
+
|
900 |
clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
|
901 |
clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
|
902 |
clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
|
|
|
910 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
911 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
912 |
app.load(fn=show_temperature_completion, outputs=[temperature_completion])
|
913 |
+
if SAVE_LOCAL_DIR is not None:
|
914 |
+
app.load(fn=show_save_local, outputs=[btn_save_local, csv_file, json_file])
|
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import json
|
2 |
import random
|
3 |
import uuid
|
@@ -19,7 +20,7 @@ from synthetic_dataset_generator.apps.base import (
|
|
19 |
validate_argilla_user_workspace_dataset,
|
20 |
validate_push_to_hub,
|
21 |
)
|
22 |
-
from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
|
23 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
24 |
from synthetic_dataset_generator.pipelines.embeddings import (
|
25 |
get_embeddings,
|
@@ -406,6 +407,33 @@ def push_dataset(
|
|
406 |
return ""
|
407 |
|
408 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
409 |
def validate_input_labels(labels: List[str]) -> List[str]:
|
410 |
if (
|
411 |
not labels
|
@@ -425,6 +453,14 @@ def hide_pipeline_code_visibility():
|
|
425 |
return {pipeline_code_ui: gr.Accordion(visible=False)}
|
426 |
|
427 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
428 |
######################
|
429 |
# Gradio UI
|
430 |
######################
|
@@ -543,7 +579,14 @@ with gr.Blocks() as app:
|
|
543 |
interactive=True,
|
544 |
scale=1,
|
545 |
)
|
546 |
-
btn_push_to_hub = gr.Button(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
547 |
with gr.Column(scale=3):
|
548 |
success_message = gr.Markdown(
|
549 |
visible=True,
|
@@ -643,6 +686,21 @@ with gr.Blocks() as app:
|
|
643 |
inputs=[],
|
644 |
outputs=[pipeline_code_ui],
|
645 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
646 |
|
647 |
gr.on(
|
648 |
triggers=[clear_btn_part.click, clear_btn_full.click],
|
@@ -660,3 +718,5 @@ with gr.Blocks() as app:
|
|
660 |
app.load(fn=swap_visibility, outputs=main_ui)
|
661 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
662 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
|
|
|
|
|
1 |
+
import os
|
2 |
import json
|
3 |
import random
|
4 |
import uuid
|
|
|
20 |
validate_argilla_user_workspace_dataset,
|
21 |
validate_push_to_hub,
|
22 |
)
|
23 |
+
from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, SAVE_LOCAL_DIR
|
24 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
25 |
from synthetic_dataset_generator.pipelines.embeddings import (
|
26 |
get_embeddings,
|
|
|
407 |
return ""
|
408 |
|
409 |
|
410 |
+
def save_local(
    system_prompt: str,
    difficulty: str,
    clarity: str,
    labels: List[str],
    multi_label: bool,
    num_rows: int,
    temperature: float,
    repo_name: str,
) -> tuple[str, str]:
    """Generate a text-classification dataset and write it to ``SAVE_LOCAL_DIR``.

    Args:
        system_prompt: Forwarded to ``generate_dataset``.
        difficulty: Forwarded to ``generate_dataset``.
        clarity: Forwarded to ``generate_dataset``.
        labels: Forwarded to ``generate_dataset``.
        multi_label: Forwarded to ``generate_dataset``.
        num_rows: Forwarded to ``generate_dataset``.
        temperature: Forwarded to ``generate_dataset``.
        repo_name: Base file name (without extension) of the two output files.

    Returns:
        The paths of the written CSV file and JSON file, in that order.
    """
    dataframe = generate_dataset(
        system_prompt=system_prompt,
        difficulty=difficulty,
        clarity=clarity,
        multi_label=multi_label,
        labels=labels,
        num_rows=num_rows,
        temperature=temperature,
    )
    local_dataset = Dataset.from_pandas(dataframe)
    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
    local_dataset.to_csv(output_csv, index=False)
    # No `index` keyword here: the JSON writer hands extra keyword arguments to
    # pandas, and pandas releases before 2.1 raise ValueError for index=False
    # with record-oriented output (which carries no index in the first place).
    local_dataset.to_json(output_json)
    return output_csv, output_json
|
435 |
+
|
436 |
+
|
437 |
def validate_input_labels(labels: List[str]) -> List[str]:
|
438 |
if (
|
439 |
not labels
|
|
|
453 |
return {pipeline_code_ui: gr.Accordion(visible=False)}
|
454 |
|
455 |
|
456 |
+
def show_save_local():
    """Map the save button and both file components to visible replacements."""
    revealed = {}
    revealed[btn_save_local] = gr.Button(visible=True)
    revealed[csv_file] = gr.File(visible=True)
    revealed[json_file] = gr.File(visible=True)
    return revealed
|
462 |
+
|
463 |
+
|
464 |
######################
|
465 |
# Gradio UI
|
466 |
######################
|
|
|
579 |
interactive=True,
|
580 |
scale=1,
|
581 |
)
|
582 |
+
btn_push_to_hub = gr.Button(
|
583 |
+
"Push to Hub", variant="primary", scale=2
|
584 |
+
)
|
585 |
+
btn_save_local = gr.Button(
|
586 |
+
"Save locally", variant="primary", scale=2, visible=False
|
587 |
+
)
|
588 |
+
csv_file = gr.File(label="CSV", elem_classes="datasets", visible=False)
|
589 |
+
json_file = gr.File(label="JSON", elem_classes="datasets", visible=False)
|
590 |
with gr.Column(scale=3):
|
591 |
success_message = gr.Markdown(
|
592 |
visible=True,
|
|
|
686 |
inputs=[],
|
687 |
outputs=[pipeline_code_ui],
|
688 |
)
|
689 |
+
|
690 |
+
btn_save_local.click(
|
691 |
+
save_local,
|
692 |
+
inputs=[
|
693 |
+
system_prompt,
|
694 |
+
difficulty,
|
695 |
+
clarity,
|
696 |
+
labels,
|
697 |
+
multi_label,
|
698 |
+
num_rows,
|
699 |
+
temperature,
|
700 |
+
repo_name,
|
701 |
+
],
|
702 |
+
outputs=[csv_file, json_file]
|
703 |
+
)
|
704 |
|
705 |
gr.on(
|
706 |
triggers=[clear_btn_part.click, clear_btn_full.click],
|
|
|
718 |
app.load(fn=swap_visibility, outputs=main_ui)
|
719 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
720 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
721 |
+
if SAVE_LOCAL_DIR is not None:
|
722 |
+
app.load(fn=show_save_local, outputs=[btn_save_local, csv_file, json_file])
|
src/synthetic_dataset_generator/constants.py
CHANGED
@@ -8,6 +8,9 @@ MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
|
|
8 |
MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
|
9 |
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
|
10 |
|
|
|
|
|
|
|
11 |
# Models
|
12 |
MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
|
13 |
TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)
|
|
|
8 |
MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
|
9 |
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
|
10 |
|
11 |
+
# Directory for outputs
|
12 |
+
SAVE_LOCAL_DIR = os.getenv(key="SAVE_LOCAL_DIR", default=None)
|
13 |
+
|
14 |
# Models
|
15 |
MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
|
16 |
TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)
|