davidberenstein1957 HF staff commited on
Commit
ec82642
·
1 Parent(s): 7b7c1be

add multi-label support instead of num labels

Browse files
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -69,14 +69,14 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
69
 
70
 
71
  def generate_sample_dataset(
72
- system_prompt, difficulty, clarity, labels, num_labels, progress=gr.Progress()
73
  ):
74
  dataframe = generate_dataset(
75
  system_prompt=system_prompt,
76
  difficulty=difficulty,
77
  clarity=clarity,
78
  labels=labels,
79
- num_labels=num_labels,
80
  num_rows=10,
81
  progress=progress,
82
  is_sample=True,
@@ -89,7 +89,7 @@ def generate_dataset(
89
  difficulty: str,
90
  clarity: str,
91
  labels: List[str] = None,
92
- num_labels: int = 1,
93
  num_rows: int = 10,
94
  temperature: float = 0.9,
95
  is_sample: bool = False,
@@ -107,7 +107,7 @@ def generate_dataset(
107
  labeller_generator = get_labeller_generator(
108
  system_prompt=f"{system_prompt} {', '.join(labels)}",
109
  labels=labels,
110
- num_labels=num_labels,
111
  )
112
  total_steps: int = num_rows * 2
113
  batch_size = DEFAULT_BATCH_SIZE
@@ -125,11 +125,11 @@ def generate_dataset(
125
  batch_size = min(batch_size, remaining_rows)
126
  inputs = []
127
  for _ in range(batch_size):
128
- if num_labels == 1:
129
- num_labels = 1
130
  else:
131
- num_labels = int(random.gammavariate(2, 2) * num_labels)
132
- sampled_labels = random.sample(labels, num_labels)
133
  random.shuffle(sampled_labels)
134
  inputs.append(
135
  {
@@ -169,12 +169,7 @@ def generate_dataset(
169
  distiset_results.append(record)
170
 
171
  dataframe = pd.DataFrame(distiset_results)
172
- if num_labels == 1:
173
- dataframe = dataframe.rename(columns={"labels": "label"})
174
- dataframe["label"] = dataframe["label"].apply(
175
- lambda x: x.lower().strip() if x.lower().strip() in labels else None
176
- )
177
- else:
178
  dataframe["labels"] = dataframe["labels"].apply(
179
  lambda x: list(
180
  set(
@@ -186,6 +181,12 @@ def generate_dataset(
186
  )
187
  )
188
  )
 
 
 
 
 
 
189
  progress(1.0, desc="Dataset created")
190
  return dataframe
191
 
@@ -194,7 +195,7 @@ def push_dataset_to_hub(
194
  dataframe: pd.DataFrame,
195
  org_name: str,
196
  repo_name: str,
197
- num_labels: int = 1,
198
  labels: List[str] = None,
199
  oauth_token: Union[gr.OAuthToken, None] = None,
200
  private: bool = False,
@@ -206,18 +207,17 @@ def push_dataset_to_hub(
206
  progress(0.3, desc="Preprocessing")
207
  labels = get_preprocess_labels(labels)
208
  progress(0.7, desc="Creating dataset")
209
- if num_labels == 1:
210
- dataframe["label"] = dataframe["label"].replace("", None)
211
- features = Features(
212
- {"text": Value("string"), "label": ClassLabel(names=labels)}
213
- )
214
- else:
215
  features = Features(
216
  {
217
  "text": Value("string"),
218
  "labels": Sequence(feature=ClassLabel(names=labels)),
219
  }
220
  )
 
 
 
 
221
  dataset = Dataset.from_pandas(dataframe, features=features)
222
  dataset = combine_datasets(repo_id, dataset)
223
  distiset = Distiset({"default": dataset})
@@ -239,7 +239,7 @@ def push_dataset(
239
  system_prompt: str,
240
  difficulty: str,
241
  clarity: str,
242
- num_labels: int = 1,
243
  num_rows: int = 10,
244
  labels: List[str] = None,
245
  private: bool = False,
@@ -252,7 +252,7 @@ def push_dataset(
252
  system_prompt=system_prompt,
253
  difficulty=difficulty,
254
  clarity=clarity,
255
- num_labels=num_labels,
256
  labels=labels,
257
  num_rows=num_rows,
258
  temperature=temperature,
@@ -261,7 +261,7 @@ def push_dataset(
261
  dataframe,
262
  org_name,
263
  repo_name,
264
- num_labels,
265
  labels,
266
  oauth_token,
267
  private,
@@ -288,19 +288,19 @@ def push_dataset(
288
  ],
289
  questions=[
290
  (
291
- rg.LabelQuestion(
292
- name="label",
293
- title="Label",
294
- description="The label of the text",
295
- labels=labels,
296
- )
297
- if num_labels == 1
298
- else rg.MultiLabelQuestion(
299
  name="labels",
300
  title="Labels",
301
  description="The labels of the conversation",
302
  labels=labels,
303
  )
 
 
 
 
 
 
 
304
  ),
305
  ],
306
  metadata=[
@@ -340,16 +340,16 @@ def push_dataset(
340
  suggestions=(
341
  [
342
  rg.Suggestion(
343
- question_name="label" if num_labels == 1 else "labels",
344
  value=(
345
- sample["label"] if num_labels == 1 else sample["labels"]
346
  ),
347
  )
348
  ]
349
  if (
350
- (num_labels == 1 and sample["label"] in labels)
351
  or (
352
- num_labels > 1
353
  and all(label in labels for label in sample["labels"])
354
  )
355
  )
@@ -373,10 +373,6 @@ def validate_input_labels(labels):
373
  return labels
374
 
375
 
376
- def update_max_num_labels(labels):
377
- return gr.update(maximum=len(labels) if labels else 1)
378
-
379
-
380
  def show_pipeline_code_visibility():
381
  return {pipeline_code_ui: gr.Accordion(visible=True)}
382
 
@@ -434,13 +430,11 @@ with gr.Blocks() as app:
434
  multiselect=True,
435
  info="Add the labels to classify the text.",
436
  )
437
- num_labels = gr.Number(
438
- label="Number of labels per text",
439
- value=1,
440
- minimum=1,
441
- maximum=10,
442
- info="Select 1 for single-label and >1 for multi-label.",
443
  interactive=True,
 
444
  )
445
  clarity = gr.Dropdown(
446
  choices=[
@@ -521,7 +515,7 @@ with gr.Blocks() as app:
521
  difficulty=difficulty.value,
522
  clarity=clarity.value,
523
  labels=labels.value,
524
- num_labels=num_labels.value,
525
  num_rows=num_rows.value,
526
  temperature=temperature.value,
527
  )
@@ -538,24 +532,14 @@ with gr.Blocks() as app:
538
  show_progress=True,
539
  ).then(
540
  fn=generate_sample_dataset,
541
- inputs=[system_prompt, difficulty, clarity, labels, num_labels],
542
  outputs=[dataframe],
543
  show_progress=True,
544
- ).then(
545
- fn=update_max_num_labels,
546
- inputs=[labels],
547
- outputs=[num_labels],
548
- )
549
-
550
- labels.input(
551
- fn=update_max_num_labels,
552
- inputs=[labels],
553
- outputs=[num_labels],
554
  )
555
 
556
  btn_apply_to_sample_dataset.click(
557
  fn=generate_sample_dataset,
558
- inputs=[system_prompt, difficulty, clarity, labels, num_labels],
559
  outputs=[dataframe],
560
  show_progress=True,
561
  )
@@ -586,7 +570,7 @@ with gr.Blocks() as app:
586
  system_prompt,
587
  difficulty,
588
  clarity,
589
- num_labels,
590
  num_rows,
591
  labels,
592
  private,
@@ -606,7 +590,7 @@ with gr.Blocks() as app:
606
  difficulty,
607
  clarity,
608
  labels,
609
- num_labels,
610
  num_rows,
611
  temperature,
612
  ],
 
69
 
70
 
71
  def generate_sample_dataset(
72
+ system_prompt, difficulty, clarity, labels, multi_label, progress=gr.Progress()
73
  ):
74
  dataframe = generate_dataset(
75
  system_prompt=system_prompt,
76
  difficulty=difficulty,
77
  clarity=clarity,
78
  labels=labels,
79
+ multi_label=multi_label,
80
  num_rows=10,
81
  progress=progress,
82
  is_sample=True,
 
89
  difficulty: str,
90
  clarity: str,
91
  labels: List[str] = None,
92
+ multi_label: bool = False,
93
  num_rows: int = 10,
94
  temperature: float = 0.9,
95
  is_sample: bool = False,
 
107
  labeller_generator = get_labeller_generator(
108
  system_prompt=f"{system_prompt} {', '.join(labels)}",
109
  labels=labels,
110
+ multi_label=multi_label,
111
  )
112
  total_steps: int = num_rows * 2
113
  batch_size = DEFAULT_BATCH_SIZE
 
125
  batch_size = min(batch_size, remaining_rows)
126
  inputs = []
127
  for _ in range(batch_size):
128
+ if multi_label:
129
+ k = min(len(labels), max(1, int(random.gammavariate(2, 2) * len(labels))))
130
  else:
131
+ k = 1
132
+ sampled_labels = random.sample(labels, k)
133
  random.shuffle(sampled_labels)
134
  inputs.append(
135
  {
 
169
  distiset_results.append(record)
170
 
171
  dataframe = pd.DataFrame(distiset_results)
172
+ if multi_label:
 
 
 
 
 
173
  dataframe["labels"] = dataframe["labels"].apply(
174
  lambda x: list(
175
  set(
 
181
  )
182
  )
183
  )
184
+ else:
185
+ dataframe = dataframe.rename(columns={"labels": "label"})
186
+ dataframe["label"] = dataframe["label"].apply(
187
+ lambda x: x.lower().strip() if x.lower().strip() in labels else None
188
+ )
189
+
190
  progress(1.0, desc="Dataset created")
191
  return dataframe
192
 
 
195
  dataframe: pd.DataFrame,
196
  org_name: str,
197
  repo_name: str,
198
+ multi_label: bool = False,
199
  labels: List[str] = None,
200
  oauth_token: Union[gr.OAuthToken, None] = None,
201
  private: bool = False,
 
207
  progress(0.3, desc="Preprocessing")
208
  labels = get_preprocess_labels(labels)
209
  progress(0.7, desc="Creating dataset")
210
+ if multi_label:
 
 
 
 
 
211
  features = Features(
212
  {
213
  "text": Value("string"),
214
  "labels": Sequence(feature=ClassLabel(names=labels)),
215
  }
216
  )
217
+ else:
218
+ features = Features(
219
+ {"text": Value("string"), "label": ClassLabel(names=labels)}
220
+ )
221
  dataset = Dataset.from_pandas(dataframe, features=features)
222
  dataset = combine_datasets(repo_id, dataset)
223
  distiset = Distiset({"default": dataset})
 
239
  system_prompt: str,
240
  difficulty: str,
241
  clarity: str,
242
+ multi_label: bool = False,
243
  num_rows: int = 10,
244
  labels: List[str] = None,
245
  private: bool = False,
 
252
  system_prompt=system_prompt,
253
  difficulty=difficulty,
254
  clarity=clarity,
255
+ multi_label=multi_label,
256
  labels=labels,
257
  num_rows=num_rows,
258
  temperature=temperature,
 
261
  dataframe,
262
  org_name,
263
  repo_name,
264
+ multi_label,
265
  labels,
266
  oauth_token,
267
  private,
 
288
  ],
289
  questions=[
290
  (
291
+ rg.MultiLabelQuestion(
 
 
 
 
 
 
 
292
  name="labels",
293
  title="Labels",
294
  description="The labels of the conversation",
295
  labels=labels,
296
  )
297
+ if multi_label
298
+ else rg.LabelQuestion(
299
+ name="label",
300
+ title="Label",
301
+ description="The label of the text",
302
+ labels=labels,
303
+ )
304
  ),
305
  ],
306
  metadata=[
 
340
  suggestions=(
341
  [
342
  rg.Suggestion(
343
+ question_name="labels" if multi_label else "label",
344
  value=(
345
+ sample["labels"] if multi_label else sample["label"]
346
  ),
347
  )
348
  ]
349
  if (
350
+ (not multi_label and sample["label"] in labels)
351
  or (
352
+ multi_label
353
  and all(label in labels for label in sample["labels"])
354
  )
355
  )
 
373
  return labels
374
 
375
 
 
 
 
 
376
  def show_pipeline_code_visibility():
377
  return {pipeline_code_ui: gr.Accordion(visible=True)}
378
 
 
430
  multiselect=True,
431
  info="Add the labels to classify the text.",
432
  )
433
+ multi_label = gr.Checkbox(
434
+ label="Multi-label",
435
+ value=False,
 
 
 
436
  interactive=True,
437
+ info="If checked, the text can be classified into multiple labels.",
438
  )
439
  clarity = gr.Dropdown(
440
  choices=[
 
515
  difficulty=difficulty.value,
516
  clarity=clarity.value,
517
  labels=labels.value,
518
+ num_labels=len(labels.value) if multi_label.value else 1,
519
  num_rows=num_rows.value,
520
  temperature=temperature.value,
521
  )
 
532
  show_progress=True,
533
  ).then(
534
  fn=generate_sample_dataset,
535
+ inputs=[system_prompt, difficulty, clarity, labels, multi_label],
536
  outputs=[dataframe],
537
  show_progress=True,
 
 
 
 
 
 
 
 
 
 
538
  )
539
 
540
  btn_apply_to_sample_dataset.click(
541
  fn=generate_sample_dataset,
542
+ inputs=[system_prompt, difficulty, clarity, labels, multi_label],
543
  outputs=[dataframe],
544
  show_progress=True,
545
  )
 
570
  system_prompt,
571
  difficulty,
572
  clarity,
573
+ multi_label,
574
  num_rows,
575
  labels,
576
  private,
 
590
  difficulty,
591
  clarity,
592
  labels,
593
+ multi_label,
594
  num_rows,
595
  temperature,
596
  ],
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -29,7 +29,7 @@ Description: DavidMovieHouse is a cinema that has been in business for 10 years.
29
  Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customers reviews of varying customer groups. Classify the customer reviews as", "labels": ["positive", "negative"]}
30
 
31
  Description: A dataset that focuses on creating neo-ludite discussions about technologies within the AI space.
32
- Output: {"classification_task": "Neo-ludiite discussions about technologies within the AI space cover from different speaking people . Categorize the discussions into one of the following categories", "labels": ["tech-support", "tech-opposition"]}
33
 
34
  Description: A dataset that covers the articles of a niche sports website called TheSportBlogs that focuses on female sports within the ballsport domain for the US market.
35
  Output: {"classification_task": "TechSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market. Written by different journalists. Determine the category of based on the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
@@ -102,7 +102,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
102
  return textcat_generator
103
 
104
 
105
- def get_labeller_generator(system_prompt, labels, num_labels):
106
  labeller_generator = TextClassification(
107
  llm=InferenceEndpointsLLM(
108
  model_id=MODEL,
@@ -115,7 +115,7 @@ def get_labeller_generator(system_prompt, labels, num_labels):
115
  ),
116
  context=system_prompt,
117
  available_labels=labels,
118
- n=num_labels,
119
  default_label="unknown",
120
  )
121
  labeller_generator.load()
 
29
  Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customers reviews of varying customer groups. Classify the customer reviews as", "labels": ["positive", "negative"]}
30
 
31
  Description: A dataset that focuses on creating neo-ludite discussions about technologies within the AI space.
32
+ Output: {"classification_task": "Neo-ludiite discussions about technologies within the AI space cover from different speaking people. Categorize the discussions into one of the following categories", "labels": ["tech-support", "tech-opposition"]}
33
 
34
  Description: A dataset that covers the articles of a niche sports website called TheSportBlogs that focuses on female sports within the ballsport domain for the US market.
35
  Output: {"classification_task": "TechSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market. Written by different journalists. Determine the category of based on the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
 
102
  return textcat_generator
103
 
104
 
105
+ def get_labeller_generator(system_prompt, labels, multi_label):
106
  labeller_generator = TextClassification(
107
  llm=InferenceEndpointsLLM(
108
  model_id=MODEL,
 
115
  ),
116
  context=system_prompt,
117
  available_labels=labels,
118
+ n=len(labels) if multi_label else 1,
119
  default_label="unknown",
120
  )
121
  labeller_generator.load()