Commit e044b6a
Parent(s): 4983843
add label randomification perf created sample
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -1,4 +1,5 @@
 import json
+import random
 import uuid
 from typing import List, Union
 
@@ -11,6 +12,7 @@ from huggingface_hub import HfApi
 
 from src.synthetic_dataset_generator.apps.base import (
     hide_success_message,
+    push_pipeline_code_to_hub,
     show_success_message,
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
@@ -119,9 +121,17 @@ def generate_dataset(
         )
         remaining_rows = num_rows - n_processed
         batch_size = min(batch_size, remaining_rows)
-        inputs = [
-
-
+        inputs = []
+        for _ in range(batch_size):
+            if num_labels == 1:
+                num_labels = 1
+            else:
+                num_labels = int(random.gammavariate(2, 2) * num_labels)
+            sampled_labels = random.sample(labels, num_labels)
+            random.shuffle(sampled_labels)
+            inputs.append(
+                {"task": f"{system_prompt}. Labels: {', '.join(sampled_labels)}"}
+            )
         batch = list(textcat_generator.process(inputs=inputs))
         textcat_results.extend(batch[0])
         n_processed += batch_size
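For intuition, here is a minimal standalone sketch of the label subsampling used in the batch-construction loop above. The label pool, the num_labels value, and the clamp to [1, len(labels)] are assumptions added for this sketch (a raw gammavariate(2, 2) draw has mean 4 and can exceed the pool size, and random.sample raises if asked for more items than exist); it illustrates the idea rather than reproducing the app's code.

import random

# Hypothetical label pool and requested label count, for illustration only.
LABELS = ["politics", "sports", "technology", "entertainment"]
NUM_LABELS = 2


def sample_label_subset(labels: list[str], num_labels: int) -> list[str]:
    """Draw a right-skewed, shuffled subset of labels (clamped variant of the loop above)."""
    if num_labels == 1:
        k = 1
    else:
        # gammavariate(2, 2) is right-skewed, so most draws ask for a few labels
        # while some ask for many; the clamp below is added for this sketch.
        k = int(random.gammavariate(2, 2) * num_labels)
        k = max(1, min(k, len(labels)))
    sampled = random.sample(labels, k)
    random.shuffle(sampled)
    return sampled


if __name__ == "__main__":
    for _ in range(5):
        print(sample_label_subset(LABELS, NUM_LABELS))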
@@ -160,6 +170,18 @@ def generate_dataset(
         dataframe["label"] = dataframe["label"].apply(
             lambda x: x.lower().strip() if x.lower().strip() in labels else None
         )
+    else:
+        dataframe["labels"] = dataframe["labels"].apply(
+            lambda x: list(
+                set(
+                    [
+                        label.lower().strip()
+                        for label in x
+                        if label.lower().strip() in labels
+                    ]
+                )
+            )
+        )
     progress(1.0, desc="Dataset generation completed")
     return dataframe
 
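The new else branch normalizes multi-label rows: labels are lowercased, stripped, filtered against the allowed set, and deduplicated. A minimal sketch of that cleanup on a toy DataFrame (the column values and allowed label set here are invented):

import pandas as pd

# Hypothetical allowed label set and raw multi-label column.
labels = ["politics", "sports", "technology"]
df = pd.DataFrame({"labels": [["Politics ", "politics", "Sports"], ["made-up", "TECHNOLOGY "]]})

# Lowercase, strip, drop unknown labels, and deduplicate, mirroring the hunk above.
df["labels"] = df["labels"].apply(
    lambda x: list(
        set(label.lower().strip() for label in x if label.lower().strip() in labels)
    )
)
print(df["labels"].tolist())  # e.g. [['politics', 'sports'], ['technology']] (set order may vary)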
@@ -172,6 +194,7 @@ def push_dataset_to_hub(
     labels: List[str] = None,
     oauth_token: Union[gr.OAuthToken, None] = None,
     private: bool = False,
+    pipeline_code: str = "",
 ):
     repo_id = validate_push_to_hub(org_name, repo_name)
     labels = get_preprocess_labels(labels)
@@ -195,6 +218,7 @@ def push_dataset_to_hub(
         token=oauth_token.token,
         create_pr=False,
     )
+    push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
 
 
 def push_dataset(
@@ -208,6 +232,7 @@ def push_dataset(
     labels: List[str] = None,
     private: bool = False,
     temperature: float = 0.8,
+    pipeline_code: str = "",
     oauth_token: Union[gr.OAuthToken, None] = None,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
@@ -221,7 +246,14 @@ def push_dataset(
         temperature=temperature,
     )
     push_dataset_to_hub(
-        dataframe,
+        dataframe,
+        org_name,
+        repo_name,
+        num_labels,
+        labels,
+        oauth_token,
+        private,
+        pipeline_code,
     )
 
     dataframe = dataframe[
@@ -407,7 +439,7 @@ with gr.Blocks() as app:
                 ("Ambiguous", "ambiguous"),
                 ("Mixed", "mixed"),
             ],
-            value="
+            value="understandable with some effort",
             label="Clarity",
             info="Set how easily the correct label or labels can be identified.",
             interactive=True,
@@ -419,7 +451,7 @@ with gr.Blocks() as app:
                 ("PhD", "PhD"),
                 ("Mixed", "mixed"),
             ],
-            value="
+            value="high school",
             label="Difficulty",
             info="Select the comprehension level for the text. Ensure it matches the task context.",
             interactive=True,
@@ -544,6 +576,7 @@ with gr.Blocks() as app:
                 labels,
                 private,
                 temperature,
+                pipeline_code,
             ],
             outputs=[success_message],
             show_progress=True,
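The two dropdown changes above preselect defaults for the Clarity and Difficulty controls. A minimal Gradio sketch with the same kind of preselected value (the surrounding layout is invented; only the dropdown arguments mirror the diff):

import gradio as gr

with gr.Blocks() as demo:
    # (label, value) choice tuples with a preselected default, as in the diff above.
    difficulty = gr.Dropdown(
        choices=[("High School", "high school"), ("PhD", "PhD"), ("Mixed", "mixed")],
        value="high school",
        label="Difficulty",
        info="Select the comprehension level for the text.",
        interactive=True,
    )

if __name__ == "__main__":
    demo.launch()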
src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED
@@ -15,35 +15,29 @@ from synthetic_dataset_generator.utils import get_preprocess_labels
 
 PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
 
-Your
+Your should write a prompt following a the dataset description. Respond with the prompt and nothing else.
 
-The prompt
+The prompt should follow the same style and structure as the following example prompts, clearly specifying the possible classification labels.
 
-
-
-{"classification_task": "Classify the following customer review of a cinema as", "labels": ["positive", "negative"]}
-
-{"classification_task": "Categorize the following news article into one or more of the following categories:", "labels": ["politics", "sports", "technology", "entertainment", "health", "business", "environment", "education", "science", "international"]}
-
-{"classification_task": "Classify the following news article into one or more of the following categories:", "labels": ['politics', 'sports', 'technology', 'entertainment', 'health', 'business', 'environment', 'education', 'science', 'international']}
+Make sure to always include all of the detailed information from the description and the context of the company that is provided.
 
-
+Don't include the labels in the classification_task but only provide a high level description of the classification task.
 
-
-
-{"classification_task": "Classify the following movie review into one of the following categories:", "labels": ['critical', 'praise', 'disappointed', 'enthusiastic']}
-
-{"classification_task": "Categorize the following customer service transcript into one of the following categories:", "labels": ['satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent']}
+If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
 
-
+Description: DavidMovieHouse is a cinema that has been in business for 10 years.
+Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customers reviews. Classify the customer reviews as", "labels": ["positive", "negative"]}
 
-
+Description: A dataset that focuses on creating neo-ludite discussions about technologies within the AI space.
+Output: {"classification_task": "Neo-ludiite discussions about technologies within the AI space cover. Categorize the discussions into one of the following categories", "labels": ["tech-support", "tech-opposition"]}
 
-
+Description: A dataset that covers the articles of a niche sports website called TheSportBlogs that focuses on female sports within the ballsport domain for the US market.
+Output: {"classification_task": "TechSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market. Determine the category of based on the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
 
-
+Description: A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has review with labels "data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"
+Output: {"classification_task": "A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has review with labels", "labels": ["data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"]}
 
-
+Description:
 """
 
 DEFAULT_DATASET_DESCRIPTIONS = [
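The rewritten prompt asks the model to answer in a Description:/Output: few-shot format whose Output line is a JSON object with classification_task and labels keys, with multi-word labels hyphenated. A small sketch of checking a response against that contract; the helpdesk example is invented:

import json

# Hypothetical model response in the "Output: ..." format shown above.
raw_output = (
    '{"classification_task": "Classify the following support ticket for a small '
    'SaaS helpdesk as", "labels": ["bug-report", "feature-request", "billing"]}'
)

parsed = json.loads(raw_output)
assert set(parsed) == {"classification_task", "labels"}
assert all(" " not in label for label in parsed["labels"])  # multi-word labels use hyphens
print(parsed["classification_task"], parsed["labels"])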
@@ -66,6 +60,19 @@ class TextClassificationTask(BaseModel):
     )
 
 
+class DatasetDescription(BaseModel):
+    description: str = Field(
+        ...,
+        title="description",
+        description="The description of the dataset.",
+    )
+    labels: list[str] = Field(
+        ...,
+        title="labels",
+        description="The possible labels for the classification task.",
+    )
+
+
 def get_prompt_generator():
     prompt_generator = TextGeneration(
         llm=InferenceEndpointsLLM(
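The new DatasetDescription model pairs a free-text description with its candidate labels. A quick standalone check of that schema (assuming Pydantic v2 and Python 3.9+; the example values are invented):

from pydantic import BaseModel, Field


class DatasetDescription(BaseModel):
    description: str = Field(..., title="description", description="The description of the dataset.")
    labels: list[str] = Field(..., title="labels", description="The possible labels for the classification task.")


example = DatasetDescription(
    description="Customer reviews for a small cinema.",
    labels=["positive", "negative"],
)
print(example.model_dump_json())  # {"description":"Customer reviews for a small cinema.","labels":["positive","negative"]}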
|