Merge pull request #6 from argilla-io/feat/improve-textcat
Browse files
src/synthetic_dataset_generator/app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from synthetic_dataset_generator._tabbedinterface import TabbedInterface
|
|
|
2 |
# from synthetic_dataset_generator.apps.eval import app as eval_app
|
3 |
from synthetic_dataset_generator.apps.readme import app as readme_app
|
4 |
from synthetic_dataset_generator.apps.sft import app as sft_app
|
@@ -15,9 +16,6 @@ button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-prima
|
|
15 |
#system_prompt_examples { color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
|
16 |
.container {padding-inline: 0 !important}
|
17 |
#sign_in_button { flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto; }
|
18 |
-
.table-view .table-wrap {
|
19 |
-
max-height: 450px;
|
20 |
-
}
|
21 |
"""
|
22 |
|
23 |
image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
|
|
|
1 |
from synthetic_dataset_generator._tabbedinterface import TabbedInterface
|
2 |
+
|
3 |
# from synthetic_dataset_generator.apps.eval import app as eval_app
|
4 |
from synthetic_dataset_generator.apps.readme import app as readme_app
|
5 |
from synthetic_dataset_generator.apps.sft import app as sft_app
|
|
|
16 |
#system_prompt_examples { color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
|
17 |
.container {padding-inline: 0 !important}
|
18 |
#sign_in_button { flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto; }
|
|
|
|
|
|
|
19 |
"""
|
20 |
|
21 |
image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
|
src/synthetic_dataset_generator/apps/eval.py
CHANGED
@@ -750,7 +750,6 @@ with gr.Blocks() as app:
|
|
750 |
headers=["prompt", "completion", "evaluation"],
|
751 |
wrap=True,
|
752 |
interactive=False,
|
753 |
-
elem_classes="table-view",
|
754 |
)
|
755 |
|
756 |
gr.HTML(value="<hr>")
|
|
|
750 |
headers=["prompt", "completion", "evaluation"],
|
751 |
wrap=True,
|
752 |
interactive=False,
|
|
|
753 |
)
|
754 |
|
755 |
gr.HTML(value="<hr>")
|
src/synthetic_dataset_generator/apps/sft.py
CHANGED
@@ -55,10 +55,10 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
|
|
55 |
|
56 |
|
57 |
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
58 |
-
progress(0.0, desc="
|
59 |
-
progress(0.3, desc="Initializing
|
60 |
generate_description = get_prompt_generator()
|
61 |
-
progress(0.7, desc="Generating
|
62 |
result = next(
|
63 |
generate_description.process(
|
64 |
[
|
@@ -68,7 +68,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
|
68 |
]
|
69 |
)
|
70 |
)[0]["generation"]
|
71 |
-
progress(1.0, desc="
|
72 |
return result
|
73 |
|
74 |
|
@@ -88,7 +88,6 @@ def _get_dataframe():
|
|
88 |
headers=["prompt", "completion"],
|
89 |
wrap=True,
|
90 |
interactive=False,
|
91 |
-
elem_classes="table-view",
|
92 |
)
|
93 |
|
94 |
|
|
|
55 |
|
56 |
|
57 |
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
58 |
+
progress(0.0, desc="Starting")
|
59 |
+
progress(0.3, desc="Initializing")
|
60 |
generate_description = get_prompt_generator()
|
61 |
+
progress(0.7, desc="Generating")
|
62 |
result = next(
|
63 |
generate_description.process(
|
64 |
[
|
|
|
68 |
]
|
69 |
)
|
70 |
)[0]["generation"]
|
71 |
+
progress(1.0, desc="Prompt generated")
|
72 |
return result
|
73 |
|
74 |
|
|
|
88 |
headers=["prompt", "completion"],
|
89 |
wrap=True,
|
90 |
interactive=False,
|
|
|
91 |
)
|
92 |
|
93 |
|
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import json
|
|
|
2 |
import uuid
|
3 |
from typing import List, Union
|
4 |
|
@@ -41,15 +42,14 @@ def _get_dataframe():
|
|
41 |
headers=["labels", "text"],
|
42 |
wrap=True,
|
43 |
interactive=False,
|
44 |
-
elem_classes="table-view",
|
45 |
)
|
46 |
|
47 |
|
48 |
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
49 |
-
progress(0.0, desc="
|
50 |
-
progress(0.3, desc="Initializing
|
51 |
generate_description = get_prompt_generator()
|
52 |
-
progress(0.7, desc="Generating
|
53 |
result = next(
|
54 |
generate_description.process(
|
55 |
[
|
@@ -59,7 +59,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
|
59 |
]
|
60 |
)
|
61 |
)[0]["generation"]
|
62 |
-
progress(1.0, desc="
|
63 |
data = json.loads(result)
|
64 |
system_prompt = data["classification_task"]
|
65 |
labels = data["labels"]
|
@@ -93,7 +93,7 @@ def generate_dataset(
|
|
93 |
is_sample: bool = False,
|
94 |
progress=gr.Progress(),
|
95 |
) -> pd.DataFrame:
|
96 |
-
progress(0.0, desc="(1/2) Generating
|
97 |
labels = get_preprocess_labels(labels)
|
98 |
textcat_generator = get_textcat_generator(
|
99 |
difficulty=difficulty,
|
@@ -116,13 +116,21 @@ def generate_dataset(
|
|
116 |
progress(
|
117 |
2 * 0.5 * n_processed / num_rows,
|
118 |
total=total_steps,
|
119 |
-
desc="(1/2) Generating
|
120 |
)
|
121 |
remaining_rows = num_rows - n_processed
|
122 |
batch_size = min(batch_size, remaining_rows)
|
123 |
-
inputs = [
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
batch = list(textcat_generator.process(inputs=inputs))
|
127 |
textcat_results.extend(batch[0])
|
128 |
n_processed += batch_size
|
@@ -130,14 +138,14 @@ def generate_dataset(
|
|
130 |
result["text"] = result["input_text"]
|
131 |
|
132 |
# label text classification data
|
133 |
-
progress(2 * 0.5, desc="(
|
134 |
n_processed = 0
|
135 |
labeller_results = []
|
136 |
while n_processed < num_rows:
|
137 |
progress(
|
138 |
0.5 + 0.5 * n_processed / num_rows,
|
139 |
total=total_steps,
|
140 |
-
desc="(
|
141 |
)
|
142 |
batch = textcat_results[n_processed : n_processed + batch_size]
|
143 |
labels_batch = list(labeller_generator.process(inputs=batch))
|
@@ -161,7 +169,19 @@ def generate_dataset(
|
|
161 |
dataframe["label"] = dataframe["label"].apply(
|
162 |
lambda x: x.lower().strip() if x.lower().strip() in labels else None
|
163 |
)
|
164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
return dataframe
|
166 |
|
167 |
|
@@ -295,7 +315,7 @@ def push_dataset(
|
|
295 |
client=client,
|
296 |
)
|
297 |
rg_dataset = rg_dataset.create()
|
298 |
-
progress(0.7, desc="Pushing dataset
|
299 |
hf_dataset = Dataset.from_pandas(dataframe)
|
300 |
records = [
|
301 |
rg.Record(
|
@@ -326,7 +346,7 @@ def push_dataset(
|
|
326 |
for sample in hf_dataset
|
327 |
]
|
328 |
rg_dataset.records.log(records=records)
|
329 |
-
progress(1.0, desc="Dataset pushed
|
330 |
except Exception as e:
|
331 |
raise gr.Error(f"Error pushing dataset to Argilla: {e}")
|
332 |
return ""
|
@@ -385,61 +405,64 @@ with gr.Blocks() as app:
|
|
385 |
|
386 |
gr.HTML("<hr>")
|
387 |
gr.Markdown("## 2. Configure your dataset")
|
388 |
-
with gr.Row(equal_height=
|
389 |
-
with gr.
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
|
|
|
|
|
|
443 |
|
444 |
gr.HTML("<hr>")
|
445 |
gr.Markdown("## 3. Generate your dataset")
|
|
|
1 |
import json
|
2 |
+
import random
|
3 |
import uuid
|
4 |
from typing import List, Union
|
5 |
|
|
|
42 |
headers=["labels", "text"],
|
43 |
wrap=True,
|
44 |
interactive=False,
|
|
|
45 |
)
|
46 |
|
47 |
|
48 |
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
49 |
+
progress(0.0, desc="Starting")
|
50 |
+
progress(0.3, desc="Initializing")
|
51 |
generate_description = get_prompt_generator()
|
52 |
+
progress(0.7, desc="Generating")
|
53 |
result = next(
|
54 |
generate_description.process(
|
55 |
[
|
|
|
59 |
]
|
60 |
)
|
61 |
)[0]["generation"]
|
62 |
+
progress(1.0, desc="Prompt generated")
|
63 |
data = json.loads(result)
|
64 |
system_prompt = data["classification_task"]
|
65 |
labels = data["labels"]
|
|
|
93 |
is_sample: bool = False,
|
94 |
progress=gr.Progress(),
|
95 |
) -> pd.DataFrame:
|
96 |
+
progress(0.0, desc="(1/2) Generating dataset")
|
97 |
labels = get_preprocess_labels(labels)
|
98 |
textcat_generator = get_textcat_generator(
|
99 |
difficulty=difficulty,
|
|
|
116 |
progress(
|
117 |
2 * 0.5 * n_processed / num_rows,
|
118 |
total=total_steps,
|
119 |
+
desc="(1/2) Generating dataset",
|
120 |
)
|
121 |
remaining_rows = num_rows - n_processed
|
122 |
batch_size = min(batch_size, remaining_rows)
|
123 |
+
inputs = []
|
124 |
+
for _ in range(batch_size):
|
125 |
+
if num_labels == 1:
|
126 |
+
num_labels = 1
|
127 |
+
else:
|
128 |
+
num_labels = int(random.gammavariate(2, 2) * num_labels)
|
129 |
+
sampled_labels = random.sample(labels, num_labels)
|
130 |
+
random.shuffle(sampled_labels)
|
131 |
+
inputs.append(
|
132 |
+
{"task": f"{system_prompt}. Labels: {', '.join(sampled_labels)}"}
|
133 |
+
)
|
134 |
batch = list(textcat_generator.process(inputs=inputs))
|
135 |
textcat_results.extend(batch[0])
|
136 |
n_processed += batch_size
|
|
|
138 |
result["text"] = result["input_text"]
|
139 |
|
140 |
# label text classification data
|
141 |
+
progress(2 * 0.5, desc="(2/2) Labeling dataset")
|
142 |
n_processed = 0
|
143 |
labeller_results = []
|
144 |
while n_processed < num_rows:
|
145 |
progress(
|
146 |
0.5 + 0.5 * n_processed / num_rows,
|
147 |
total=total_steps,
|
148 |
+
desc="(2/2) Labeling dataset",
|
149 |
)
|
150 |
batch = textcat_results[n_processed : n_processed + batch_size]
|
151 |
labels_batch = list(labeller_generator.process(inputs=batch))
|
|
|
169 |
dataframe["label"] = dataframe["label"].apply(
|
170 |
lambda x: x.lower().strip() if x.lower().strip() in labels else None
|
171 |
)
|
172 |
+
else:
|
173 |
+
dataframe["labels"] = dataframe["labels"].apply(
|
174 |
+
lambda x: list(
|
175 |
+
set(
|
176 |
+
[
|
177 |
+
label.lower().strip()
|
178 |
+
for label in x
|
179 |
+
if label.lower().strip() in labels
|
180 |
+
]
|
181 |
+
)
|
182 |
+
)
|
183 |
+
)
|
184 |
+
progress(1.0, desc="Dataset created")
|
185 |
return dataframe
|
186 |
|
187 |
|
|
|
315 |
client=client,
|
316 |
)
|
317 |
rg_dataset = rg_dataset.create()
|
318 |
+
progress(0.7, desc="Pushing dataset")
|
319 |
hf_dataset = Dataset.from_pandas(dataframe)
|
320 |
records = [
|
321 |
rg.Record(
|
|
|
346 |
for sample in hf_dataset
|
347 |
]
|
348 |
rg_dataset.records.log(records=records)
|
349 |
+
progress(1.0, desc="Dataset pushed")
|
350 |
except Exception as e:
|
351 |
raise gr.Error(f"Error pushing dataset to Argilla: {e}")
|
352 |
return ""
|
|
|
405 |
|
406 |
gr.HTML("<hr>")
|
407 |
gr.Markdown("## 2. Configure your dataset")
|
408 |
+
with gr.Row(equal_height=True):
|
409 |
+
with gr.Row(equal_height=False):
|
410 |
+
with gr.Column(scale=2):
|
411 |
+
system_prompt = gr.Textbox(
|
412 |
+
label="System prompt",
|
413 |
+
placeholder="You are a helpful assistant.",
|
414 |
+
visible=True,
|
415 |
+
)
|
416 |
+
labels = gr.Dropdown(
|
417 |
+
choices=[],
|
418 |
+
allow_custom_value=True,
|
419 |
+
interactive=True,
|
420 |
+
label="Labels",
|
421 |
+
multiselect=True,
|
422 |
+
info="Add the labels to classify the text.",
|
423 |
+
)
|
424 |
+
num_labels = gr.Number(
|
425 |
+
label="Number of labels per text",
|
426 |
+
value=1,
|
427 |
+
minimum=1,
|
428 |
+
maximum=10,
|
429 |
+
info="Select 1 for single-label and >1 for multi-label.",
|
430 |
+
interactive=True,
|
431 |
+
)
|
432 |
+
clarity = gr.Dropdown(
|
433 |
+
choices=[
|
434 |
+
("Clear", "clear"),
|
435 |
+
(
|
436 |
+
"Understandable",
|
437 |
+
"understandable with some effort",
|
438 |
+
),
|
439 |
+
("Ambiguous", "ambiguous"),
|
440 |
+
("Mixed", "mixed"),
|
441 |
+
],
|
442 |
+
value="understandable with some effort",
|
443 |
+
label="Clarity",
|
444 |
+
info="Set how easily the correct label or labels can be identified.",
|
445 |
+
interactive=True,
|
446 |
+
)
|
447 |
+
difficulty = gr.Dropdown(
|
448 |
+
choices=[
|
449 |
+
("High School", "high school"),
|
450 |
+
("College", "college"),
|
451 |
+
("PhD", "PhD"),
|
452 |
+
("Mixed", "mixed"),
|
453 |
+
],
|
454 |
+
value="high school",
|
455 |
+
label="Difficulty",
|
456 |
+
info="Select the comprehension level for the text. Ensure it matches the task context.",
|
457 |
+
interactive=True,
|
458 |
+
)
|
459 |
+
with gr.Row():
|
460 |
+
clear_btn_full = gr.Button("Clear", variant="secondary")
|
461 |
+
btn_apply_to_sample_dataset = gr.Button(
|
462 |
+
"Save", variant="primary"
|
463 |
+
)
|
464 |
+
with gr.Column(scale=3):
|
465 |
+
dataframe = _get_dataframe()
|
466 |
|
467 |
gr.HTML("<hr>")
|
468 |
gr.Markdown("## 3. Generate your dataset")
|
src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED
@@ -15,35 +15,29 @@ from synthetic_dataset_generator.utils import get_preprocess_labels
|
|
15 |
|
16 |
PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
|
17 |
|
18 |
-
Your
|
19 |
|
20 |
-
The prompt
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
{"classification_task": "Classify the following customer review of a cinema as", "labels": ["positive", "negative"]}
|
25 |
-
|
26 |
-
{"classification_task": "Categorize the following news article into one or more of the following categories:", "labels": ["politics", "sports", "technology", "entertainment", "health", "business", "environment", "education", "science", "international"]}
|
27 |
-
|
28 |
-
{"classification_task": "Classify the following news article into one or more of the following categories:", "labels": ['politics', 'sports', 'technology', 'entertainment', 'health', 'business', 'environment', 'education', 'science', 'international']}
|
29 |
|
30 |
-
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
{"classification_task": "Classify the following movie review into one of the following categories:", "labels": ['critical', 'praise', 'disappointed', 'enthusiastic']}
|
35 |
-
|
36 |
-
{"classification_task": "Categorize the following customer service transcript into one of the following categories:", "labels": ['satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent']}
|
37 |
|
38 |
-
|
|
|
39 |
|
40 |
-
|
|
|
41 |
|
42 |
-
|
|
|
43 |
|
44 |
-
|
|
|
45 |
|
46 |
-
|
47 |
"""
|
48 |
|
49 |
DEFAULT_DATASET_DESCRIPTIONS = [
|
|
|
15 |
|
16 |
PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
|
17 |
|
18 |
+
Your should write a prompt following a the dataset description. Respond with the prompt and nothing else.
|
19 |
|
20 |
+
The prompt should follow the same style and structure as the following example prompts, clearly specifying the possible classification labels.
|
21 |
|
22 |
+
Make sure to always include all of the detailed information from the description and the context of the company that is provided.
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
+
Don't include the labels in the classification_task but only provide a high level description of the classification task.
|
25 |
|
26 |
+
If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
|
|
|
|
|
|
|
|
|
27 |
|
28 |
+
Description: DavidMovieHouse is a cinema that has been in business for 10 years.
|
29 |
+
Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customers reviews. Classify the customer reviews as", "labels": ["positive", "negative"]}
|
30 |
|
31 |
+
Description: A dataset that focuses on creating neo-ludite discussions about technologies within the AI space.
|
32 |
+
Output: {"classification_task": "Neo-ludiite discussions about technologies within the AI space cover. Categorize the discussions into one of the following categories", "labels": ["tech-support", "tech-opposition"]}
|
33 |
|
34 |
+
Description: A dataset that covers the articles of a niche sports website called TheSportBlogs that focuses on female sports within the ballsport domain for the US market.
|
35 |
+
Output: {"classification_task": "TechSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market. Determine the category of based on the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
|
36 |
|
37 |
+
Description: A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has review with labels "data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"
|
38 |
+
Output: {"classification_task": "A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has review with labels", "labels": ["data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"]}
|
39 |
|
40 |
+
Description:
|
41 |
"""
|
42 |
|
43 |
DEFAULT_DATASET_DESCRIPTIONS = [
|