davidberenstein1957 HF staff committed on
Commit
d15b1c7
·
unverified ·
2 Parent(s): 79801ad 136bd13

Merge pull request #6 from argilla-io/feat/improve-textcat

Browse files
src/synthetic_dataset_generator/app.py CHANGED
@@ -1,4 +1,5 @@
1
  from synthetic_dataset_generator._tabbedinterface import TabbedInterface
 
2
  # from synthetic_dataset_generator.apps.eval import app as eval_app
3
  from synthetic_dataset_generator.apps.readme import app as readme_app
4
  from synthetic_dataset_generator.apps.sft import app as sft_app
@@ -15,9 +16,6 @@ button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-prima
15
  #system_prompt_examples { color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
16
  .container {padding-inline: 0 !important}
17
  #sign_in_button { flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto; }
18
- .table-view .table-wrap {
19
- max-height: 450px;
20
- }
21
  """
22
 
23
  image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
 
1
  from synthetic_dataset_generator._tabbedinterface import TabbedInterface
2
+
3
  # from synthetic_dataset_generator.apps.eval import app as eval_app
4
  from synthetic_dataset_generator.apps.readme import app as readme_app
5
  from synthetic_dataset_generator.apps.sft import app as sft_app
 
16
  #system_prompt_examples { color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
17
  .container {padding-inline: 0 !important}
18
  #sign_in_button { flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto; }
 
 
 
19
  """
20
 
21
  image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
src/synthetic_dataset_generator/apps/eval.py CHANGED
@@ -750,7 +750,6 @@ with gr.Blocks() as app:
750
  headers=["prompt", "completion", "evaluation"],
751
  wrap=True,
752
  interactive=False,
753
- elem_classes="table-view",
754
  )
755
 
756
  gr.HTML(value="<hr>")
 
750
  headers=["prompt", "completion", "evaluation"],
751
  wrap=True,
752
  interactive=False,
 
753
  )
754
 
755
  gr.HTML(value="<hr>")
src/synthetic_dataset_generator/apps/sft.py CHANGED
@@ -55,10 +55,10 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
55
 
56
 
57
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
58
- progress(0.0, desc="Generating system prompt")
59
- progress(0.3, desc="Initializing text generation")
60
  generate_description = get_prompt_generator()
61
- progress(0.7, desc="Generating system prompt")
62
  result = next(
63
  generate_description.process(
64
  [
@@ -68,7 +68,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
68
  ]
69
  )
70
  )[0]["generation"]
71
- progress(1.0, desc="System prompt generated")
72
  return result
73
 
74
 
@@ -88,7 +88,6 @@ def _get_dataframe():
88
  headers=["prompt", "completion"],
89
  wrap=True,
90
  interactive=False,
91
- elem_classes="table-view",
92
  )
93
 
94
 
 
55
 
56
 
57
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
58
+ progress(0.0, desc="Starting")
59
+ progress(0.3, desc="Initializing")
60
  generate_description = get_prompt_generator()
61
+ progress(0.7, desc="Generating")
62
  result = next(
63
  generate_description.process(
64
  [
 
68
  ]
69
  )
70
  )[0]["generation"]
71
+ progress(1.0, desc="Prompt generated")
72
  return result
73
 
74
 
 
88
  headers=["prompt", "completion"],
89
  wrap=True,
90
  interactive=False,
 
91
  )
92
 
93
 
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -1,4 +1,5 @@
1
  import json
 
2
  import uuid
3
  from typing import List, Union
4
 
@@ -41,15 +42,14 @@ def _get_dataframe():
41
  headers=["labels", "text"],
42
  wrap=True,
43
  interactive=False,
44
- elem_classes="table-view",
45
  )
46
 
47
 
48
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
49
- progress(0.0, desc="Generating text classification task")
50
- progress(0.3, desc="Initializing text generation")
51
  generate_description = get_prompt_generator()
52
- progress(0.7, desc="Generating text classification task")
53
  result = next(
54
  generate_description.process(
55
  [
@@ -59,7 +59,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
59
  ]
60
  )
61
  )[0]["generation"]
62
- progress(1.0, desc="Text classification task generated")
63
  data = json.loads(result)
64
  system_prompt = data["classification_task"]
65
  labels = data["labels"]
@@ -93,7 +93,7 @@ def generate_dataset(
93
  is_sample: bool = False,
94
  progress=gr.Progress(),
95
  ) -> pd.DataFrame:
96
- progress(0.0, desc="(1/2) Generating text classification data")
97
  labels = get_preprocess_labels(labels)
98
  textcat_generator = get_textcat_generator(
99
  difficulty=difficulty,
@@ -116,13 +116,21 @@ def generate_dataset(
116
  progress(
117
  2 * 0.5 * n_processed / num_rows,
118
  total=total_steps,
119
- desc="(1/2) Generating text classification data",
120
  )
121
  remaining_rows = num_rows - n_processed
122
  batch_size = min(batch_size, remaining_rows)
123
- inputs = [
124
- {"task": f"{system_prompt} {', '.join(labels)}"} for _ in range(batch_size)
125
- ]
 
 
 
 
 
 
 
 
126
  batch = list(textcat_generator.process(inputs=inputs))
127
  textcat_results.extend(batch[0])
128
  n_processed += batch_size
@@ -130,14 +138,14 @@ def generate_dataset(
130
  result["text"] = result["input_text"]
131
 
132
  # label text classification data
133
- progress(2 * 0.5, desc="(1/2) Generating text classification data")
134
  n_processed = 0
135
  labeller_results = []
136
  while n_processed < num_rows:
137
  progress(
138
  0.5 + 0.5 * n_processed / num_rows,
139
  total=total_steps,
140
- desc="(1/2) Labeling text classification data",
141
  )
142
  batch = textcat_results[n_processed : n_processed + batch_size]
143
  labels_batch = list(labeller_generator.process(inputs=batch))
@@ -161,7 +169,19 @@ def generate_dataset(
161
  dataframe["label"] = dataframe["label"].apply(
162
  lambda x: x.lower().strip() if x.lower().strip() in labels else None
163
  )
164
- progress(1.0, desc="Dataset generation completed")
 
 
 
 
 
 
 
 
 
 
 
 
165
  return dataframe
166
 
167
 
@@ -295,7 +315,7 @@ def push_dataset(
295
  client=client,
296
  )
297
  rg_dataset = rg_dataset.create()
298
- progress(0.7, desc="Pushing dataset to Argilla")
299
  hf_dataset = Dataset.from_pandas(dataframe)
300
  records = [
301
  rg.Record(
@@ -326,7 +346,7 @@ def push_dataset(
326
  for sample in hf_dataset
327
  ]
328
  rg_dataset.records.log(records=records)
329
- progress(1.0, desc="Dataset pushed to Argilla")
330
  except Exception as e:
331
  raise gr.Error(f"Error pushing dataset to Argilla: {e}")
332
  return ""
@@ -385,61 +405,64 @@ with gr.Blocks() as app:
385
 
386
  gr.HTML("<hr>")
387
  gr.Markdown("## 2. Configure your dataset")
388
- with gr.Row(equal_height=False):
389
- with gr.Column(scale=2):
390
- system_prompt = gr.Textbox(
391
- label="System prompt",
392
- placeholder="You are a helpful assistant.",
393
- visible=True,
394
- )
395
- labels = gr.Dropdown(
396
- choices=[],
397
- allow_custom_value=True,
398
- interactive=True,
399
- label="Labels",
400
- multiselect=True,
401
- info="Add the labels to classify the text.",
402
- )
403
- num_labels = gr.Number(
404
- label="Number of labels per text",
405
- value=1,
406
- minimum=1,
407
- maximum=10,
408
- info="Select 1 for single-label and >1 for multi-label.",
409
- interactive=True,
410
- )
411
- clarity = gr.Dropdown(
412
- choices=[
413
- ("Clear", "clear"),
414
- (
415
- "Understandable",
416
- "understandable with some effort",
417
- ),
418
- ("Ambiguous", "ambiguous"),
419
- ("Mixed", "mixed"),
420
- ],
421
- value="mixed",
422
- label="Clarity",
423
- info="Set how easily the correct label or labels can be identified.",
424
- interactive=True,
425
- )
426
- difficulty = gr.Dropdown(
427
- choices=[
428
- ("High School", "high school"),
429
- ("College", "college"),
430
- ("PhD", "PhD"),
431
- ("Mixed", "mixed"),
432
- ],
433
- value="mixed",
434
- label="Difficulty",
435
- info="Select the comprehension level for the text. Ensure it matches the task context.",
436
- interactive=True,
437
- )
438
- with gr.Row():
439
- clear_btn_full = gr.Button("Clear", variant="secondary")
440
- btn_apply_to_sample_dataset = gr.Button("Save", variant="primary")
441
- with gr.Column(scale=3):
442
- dataframe = _get_dataframe()
 
 
 
443
 
444
  gr.HTML("<hr>")
445
  gr.Markdown("## 3. Generate your dataset")
 
1
  import json
2
+ import random
3
  import uuid
4
  from typing import List, Union
5
 
 
42
  headers=["labels", "text"],
43
  wrap=True,
44
  interactive=False,
 
45
  )
46
 
47
 
48
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
49
+ progress(0.0, desc="Starting")
50
+ progress(0.3, desc="Initializing")
51
  generate_description = get_prompt_generator()
52
+ progress(0.7, desc="Generating")
53
  result = next(
54
  generate_description.process(
55
  [
 
59
  ]
60
  )
61
  )[0]["generation"]
62
+ progress(1.0, desc="Prompt generated")
63
  data = json.loads(result)
64
  system_prompt = data["classification_task"]
65
  labels = data["labels"]
 
93
  is_sample: bool = False,
94
  progress=gr.Progress(),
95
  ) -> pd.DataFrame:
96
+ progress(0.0, desc="(1/2) Generating dataset")
97
  labels = get_preprocess_labels(labels)
98
  textcat_generator = get_textcat_generator(
99
  difficulty=difficulty,
 
116
  progress(
117
  2 * 0.5 * n_processed / num_rows,
118
  total=total_steps,
119
+ desc="(1/2) Generating dataset",
120
  )
121
  remaining_rows = num_rows - n_processed
122
  batch_size = min(batch_size, remaining_rows)
123
+ inputs = []
124
+ for _ in range(batch_size):
125
+ if num_labels == 1:
126
+ num_labels = 1
127
+ else:
128
+ num_labels = int(random.gammavariate(2, 2) * num_labels)
129
+ sampled_labels = random.sample(labels, num_labels)
130
+ random.shuffle(sampled_labels)
131
+ inputs.append(
132
+ {"task": f"{system_prompt}. Labels: {', '.join(sampled_labels)}"}
133
+ )
134
  batch = list(textcat_generator.process(inputs=inputs))
135
  textcat_results.extend(batch[0])
136
  n_processed += batch_size
 
138
  result["text"] = result["input_text"]
139
 
140
  # label text classification data
141
+ progress(2 * 0.5, desc="(2/2) Labeling dataset")
142
  n_processed = 0
143
  labeller_results = []
144
  while n_processed < num_rows:
145
  progress(
146
  0.5 + 0.5 * n_processed / num_rows,
147
  total=total_steps,
148
+ desc="(2/2) Labeling dataset",
149
  )
150
  batch = textcat_results[n_processed : n_processed + batch_size]
151
  labels_batch = list(labeller_generator.process(inputs=batch))
 
169
  dataframe["label"] = dataframe["label"].apply(
170
  lambda x: x.lower().strip() if x.lower().strip() in labels else None
171
  )
172
+ else:
173
+ dataframe["labels"] = dataframe["labels"].apply(
174
+ lambda x: list(
175
+ set(
176
+ [
177
+ label.lower().strip()
178
+ for label in x
179
+ if label.lower().strip() in labels
180
+ ]
181
+ )
182
+ )
183
+ )
184
+ progress(1.0, desc="Dataset created")
185
  return dataframe
186
 
187
 
 
315
  client=client,
316
  )
317
  rg_dataset = rg_dataset.create()
318
+ progress(0.7, desc="Pushing dataset")
319
  hf_dataset = Dataset.from_pandas(dataframe)
320
  records = [
321
  rg.Record(
 
346
  for sample in hf_dataset
347
  ]
348
  rg_dataset.records.log(records=records)
349
+ progress(1.0, desc="Dataset pushed")
350
  except Exception as e:
351
  raise gr.Error(f"Error pushing dataset to Argilla: {e}")
352
  return ""
 
405
 
406
  gr.HTML("<hr>")
407
  gr.Markdown("## 2. Configure your dataset")
408
+ with gr.Row(equal_height=True):
409
+ with gr.Row(equal_height=False):
410
+ with gr.Column(scale=2):
411
+ system_prompt = gr.Textbox(
412
+ label="System prompt",
413
+ placeholder="You are a helpful assistant.",
414
+ visible=True,
415
+ )
416
+ labels = gr.Dropdown(
417
+ choices=[],
418
+ allow_custom_value=True,
419
+ interactive=True,
420
+ label="Labels",
421
+ multiselect=True,
422
+ info="Add the labels to classify the text.",
423
+ )
424
+ num_labels = gr.Number(
425
+ label="Number of labels per text",
426
+ value=1,
427
+ minimum=1,
428
+ maximum=10,
429
+ info="Select 1 for single-label and >1 for multi-label.",
430
+ interactive=True,
431
+ )
432
+ clarity = gr.Dropdown(
433
+ choices=[
434
+ ("Clear", "clear"),
435
+ (
436
+ "Understandable",
437
+ "understandable with some effort",
438
+ ),
439
+ ("Ambiguous", "ambiguous"),
440
+ ("Mixed", "mixed"),
441
+ ],
442
+ value="understandable with some effort",
443
+ label="Clarity",
444
+ info="Set how easily the correct label or labels can be identified.",
445
+ interactive=True,
446
+ )
447
+ difficulty = gr.Dropdown(
448
+ choices=[
449
+ ("High School", "high school"),
450
+ ("College", "college"),
451
+ ("PhD", "PhD"),
452
+ ("Mixed", "mixed"),
453
+ ],
454
+ value="high school",
455
+ label="Difficulty",
456
+ info="Select the comprehension level for the text. Ensure it matches the task context.",
457
+ interactive=True,
458
+ )
459
+ with gr.Row():
460
+ clear_btn_full = gr.Button("Clear", variant="secondary")
461
+ btn_apply_to_sample_dataset = gr.Button(
462
+ "Save", variant="primary"
463
+ )
464
+ with gr.Column(scale=3):
465
+ dataframe = _get_dataframe()
466
 
467
  gr.HTML("<hr>")
468
  gr.Markdown("## 3. Generate your dataset")
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -15,35 +15,29 @@ from synthetic_dataset_generator.utils import get_preprocess_labels
15
 
16
  PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
17
 
18
- Your task is to write a prompt following the instruction of the user. Respond with the prompt and nothing else.
19
 
20
- The prompt you write should follow the same style and structure as the following example prompts, clearly specifying the possible classification labels.
21
 
22
- If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
23
-
24
- {"classification_task": "Classify the following customer review of a cinema as", "labels": ["positive", "negative"]}
25
-
26
- {"classification_task": "Categorize the following news article into one or more of the following categories:", "labels": ["politics", "sports", "technology", "entertainment", "health", "business", "environment", "education", "science", "international"]}
27
-
28
- {"classification_task": "Classify the following news article into one or more of the following categories:", "labels": ['politics', 'sports', 'technology', 'entertainment', 'health', 'business', 'environment', 'education', 'science', 'international']}
29
 
30
- {"classification_task": "Determine the sentiment of the following social media post:", "labels": ['ambiguous', 'sarcastic', 'informative', 'emotional']}
31
 
32
- {"classification_task": "Identify the issue category for the following technical support ticket:", "labels": ['billing', 'technical', 'account', 'shipping', 'returns', 'installation', 'subscription']}
33
-
34
- {"classification_task": "Classify the following movie review into one of the following categories:", "labels": ['critical', 'praise', 'disappointed', 'enthusiastic']}
35
-
36
- {"classification_task": "Categorize the following customer service transcript into one of the following categories:", "labels": ['satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent']}
37
 
38
- {"classification_task": "Classify the following product description into one of the following product types:", "labels": ['smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones']}
 
39
 
40
- {"classification_task": "Categorize the following tweet expressing the political event discussed as", "labels": ['support', 'opposition']}
 
41
 
42
- {"classification_task": "Classify the following restaurant review into one of the following categories:", "labels": ['food-quality', 'service', 'ambiance', 'price']}
 
43
 
44
- {"classification_task": "Categorize the following blog post based on its primary fashion trend or style:", "labels": ['casual', 'formal', 'streetwear', 'vintage', 'sustainable-fashion']}
 
45
 
46
- User dataset description:
47
  """
48
 
49
  DEFAULT_DATASET_DESCRIPTIONS = [
 
15
 
16
  PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
17
 
18
+ You should write a prompt following the dataset description. Respond with the prompt and nothing else.
19
 
20
+ The prompt should follow the same style and structure as the following example prompts, clearly specifying the possible classification labels.
21
 
22
+ Make sure to always include all of the detailed information from the description and the context of the company that is provided.
 
 
 
 
 
 
23
 
24
+ Don't include the labels in the classification_task but only provide a high level description of the classification task.
25
 
26
+ If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
 
 
 
 
27
 
28
+ Description: DavidMovieHouse is a cinema that has been in business for 10 years.
29
+ Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customer reviews. Classify the customer reviews as", "labels": ["positive", "negative"]}
30
 
31
+ Description: A dataset that focuses on creating neo-Luddite discussions about technologies within the AI space.
32
+ Output: {"classification_task": "Neo-Luddite discussions about technologies within the AI space. Categorize the discussions into one of the following categories", "labels": ["tech-support", "tech-opposition"]}
33
 
34
+ Description: A dataset that covers the articles of a niche sports website called TheSportBlogs that focuses on female sports within the ballsport domain for the US market.
35
+ Output: {"classification_task": "TheSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market. Determine the category based on the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
36
 
37
+ Description: A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has reviews with labels "data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"
38
+ Output: {"classification_task": "A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has reviews with labels", "labels": ["data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"]}
39
 
40
+ Description:
41
  """
42
 
43
  DEFAULT_DATASET_DESCRIPTIONS = [