Commit 2841b26
Parent(s): 32d8669

add randomisation of system prompts for generation

Files changed:
- README.md +1 -1
- examples/ollama_deployment.py +4 -1
- src/synthetic_dataset_generator/apps/base.py +7 -2
- src/synthetic_dataset_generator/apps/chat.py +8 -11
- src/synthetic_dataset_generator/apps/eval.py +3 -1
- src/synthetic_dataset_generator/apps/textcat.py +10 -6
- src/synthetic_dataset_generator/constants.py +1 -1
- src/synthetic_dataset_generator/pipelines/base.py +40 -0
- src/synthetic_dataset_generator/pipelines/chat.py +0 -9
- src/synthetic_dataset_generator/pipelines/textcat.py +0 -2
- src/synthetic_dataset_generator/utils.py +5 -0
README.md
CHANGED
@@ -87,7 +87,7 @@ Optionally, you can use different API providers and models.
 - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
 - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
 - `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
-- `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434
+- `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/`.
 - `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`.

 SFT and Chat Data generation is only supported with Hugging Face Inference Endpoints, and you can set the following environment variables use it with models other than Llama3 and Qwen2.
examples/ollama_deployment.py
CHANGED
@@ -9,6 +9,9 @@ assert os.getenv("HF_TOKEN") # push the data to huggingface
 os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/"  # ollama base url
 os.environ["MODEL"] = "llama3.1:8b-instruct-q8_0"  # model id
 os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.1-8B-Instruct"  # tokenizer id
-os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"
+os.environ["MAX_NUM_ROWS"] = "10000"
+os.environ["DEFAULT_BATCH_SIZE"] = "5"
+os.environ["MAX_NUM_TOKENS"] = "1024"

 launch()
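The example now also caps dataset size, batch size, and token budget through environment variables before calling launch(). For orientation, a minimal sketch of the whole script after this commit; the import path of launch and the HF_TOKEN assertion are assumed from the unchanged top of the file, which only appears here in the hunk header:

    import os

    from synthetic_dataset_generator import launch  # assumed import path

    assert os.getenv("HF_TOKEN")  # push the data to huggingface

    os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/"  # ollama base url
    os.environ["MODEL"] = "llama3.1:8b-instruct-q8_0"  # model id
    os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.1-8B-Instruct"  # tokenizer id
    os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # magpie template
    os.environ["MAX_NUM_ROWS"] = "10000"  # cap on generated rows
    os.environ["DEFAULT_BATCH_SIZE"] = "5"  # rows generated per batch
    os.environ["MAX_NUM_TOKENS"] = "1024"  # cap on tokens per generation

    launch()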
src/synthetic_dataset_generator/apps/base.py
CHANGED
@@ -77,10 +77,15 @@ def validate_push_to_hub(org_name, repo_name):
     return repo_id


-def combine_datasets(
+def combine_datasets(
+    repo_id: str, dataset: Dataset, oauth_token: Union[OAuthToken, None]
+) -> Dataset:
     try:
         new_dataset = load_dataset(
-            repo_id,
+            repo_id,
+            split="train",
+            download_mode="force_redownload",
+            token=oauth_token.token,
         )
         return concatenate_datasets([dataset, new_dataset])
     except Exception:
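With the extra arguments, every push now appends to whatever already lives at repo_id instead of overwriting it, re-downloading with the caller's OAuth token so a stale cache is never merged. A minimal standalone sketch of the pattern, assuming the same datasets API (the function name and the first-push fallback are illustrative):

    from typing import Union

    from datasets import Dataset, concatenate_datasets, load_dataset


    def combine_with_existing(repo_id: str, dataset: Dataset, token: Union[str, None]) -> Dataset:
        try:
            # force_redownload bypasses the local cache so a stale copy of the
            # repo is never concatenated; split="train" selects the pushed split
            existing = load_dataset(
                repo_id, split="train", download_mode="force_redownload", token=token
            )
            return concatenate_datasets([dataset, existing])
        except Exception:
            # assumed first-push behaviour: the repo does not exist yet
            return dataset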
src/synthetic_dataset_generator/apps/chat.py
CHANGED
@@ -25,12 +25,12 @@ from synthetic_dataset_generator.constants import (
     MODEL,
     SFT_AVAILABLE,
 )
+from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts
 from synthetic_dataset_generator.pipelines.chat import (
     DEFAULT_DATASET_DESCRIPTIONS,
     generate_pipeline_code,
     get_magpie_generator,
     get_prompt_generator,
-    get_prompt_rewriter,
     get_response_generator,
 )
 from synthetic_dataset_generator.pipelines.embeddings import (
@@ -40,6 +40,7 @@ from synthetic_dataset_generator.pipelines.embeddings import (
 from synthetic_dataset_generator.utils import (
     get_argilla_client,
     get_org_dropdown,
+    get_random_repo_name,
     swap_visibility,
 )

@@ -106,7 +107,6 @@ def generate_dataset(
 ) -> pd.DataFrame:
     num_rows = test_max_num_rows(num_rows)
     progress(0.0, desc="(1/2) Generating instructions")
-    prompt_rewriter = get_prompt_rewriter()
     magpie_generator = get_magpie_generator(
         system_prompt, num_turns, temperature, is_sample
     )
@@ -117,14 +117,7 @@ def generate_dataset(
     batch_size = DEFAULT_BATCH_SIZE

     # create prompt rewrites
-    inputs = [
-        {
-            "instruction": f"Rewrite this prompt keeping the same structure but highlighting different aspects of the original without adding anything new. Original prompt: {system_prompt} Rewritten prompt: "
-        }
-        for i in range(int(num_rows / 100))
-    ]
-    batch = list(prompt_rewriter.process(inputs=inputs))
-    prompt_rewrites = [entry["generation"] for entry in batch[0]] + [system_prompt]
+    prompt_rewrites = get_rewriten_prompts(system_prompt, num_rows)

     # create instructions
     n_processed = 0
@@ -142,6 +135,7 @@ def generate_dataset(
         batch = list(magpie_generator.process(inputs=inputs))
         magpie_results.extend(batch[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
     progress(0.5, desc="(1/2) Generating instructions")

     # generate responses
@@ -158,6 +152,7 @@ def generate_dataset(
         responses = list(response_generator.process(inputs=batch))
         response_results.extend(responses[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
     for result in response_results:
         result["prompt"] = result["instruction"]
         result["completion"] = result["generation"]
@@ -178,6 +173,7 @@ def generate_dataset(
         responses = list(response_generator.process(inputs=batch))
         response_results.extend(responses[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
     for result in response_results:
         result["messages"].append(
             {"role": "assistant", "content": result["generation"]}
@@ -236,7 +232,7 @@ def push_dataset_to_hub(
     dataframe = convert_dataframe_messages(dataframe)
     progress(0.7, desc="Creating dataset")
     dataset = Dataset.from_pandas(dataframe)
-    dataset = combine_datasets(repo_id, dataset)
+    dataset = combine_datasets(repo_id, dataset, oauth_token)
     progress(0.9, desc="Pushing dataset")
     distiset = Distiset({"default": dataset})
     distiset.push_to_hub(
@@ -600,4 +596,5 @@ with gr.Blocks() as app:
         outputs=[dataset_description, system_prompt, num_turns, dataframe],
     )
     app.load(fn=get_org_dropdown, outputs=[org_name])
+    app.load(fn=get_random_repo_name, outputs=[repo_name])
     app.load(fn=swap_visibility, outputs=main_ui)
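The recurring random.seed(a=random.randint(0, 2**32 - 1)) line re-seeds Python's global RNG with a value drawn from its own stream after each batch, jumping to a fresh state before the next round of random.choice/random.sample calls. A standalone illustration of the idiom (not tied to this repo):

    import random

    for batch_idx in range(3):
        # draws in this batch come from the current RNG state
        picks = [random.choice(["a", "b", "c"]) for _ in range(5)]
        print(batch_idx, picks)
        # jump the global RNG to a new state before the next batch
        random.seed(a=random.randint(0, 2**32 - 1))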
src/synthetic_dataset_generator/apps/eval.py
CHANGED
@@ -41,6 +41,7 @@ from synthetic_dataset_generator.utils import (
     extract_column_names,
     get_argilla_client,
     get_org_dropdown,
+    get_random_repo_name,
     pad_or_truncate_list,
     process_columns,
     swap_visibility,
@@ -359,7 +360,7 @@ def push_dataset_to_hub(
 ):
     repo_id = validate_push_to_hub(org_name, repo_name)
     dataset = Dataset.from_pandas(dataframe)
-    dataset = combine_datasets(repo_id, dataset)
+    dataset = combine_datasets(repo_id, dataset, oauth_token)
     distiset = Distiset({"default": dataset})
     distiset.push_to_hub(
         repo_id=repo_id,
@@ -907,3 +908,4 @@ with gr.Blocks() as app:

     app.load(fn=swap_visibility, outputs=main_ui)
     app.load(fn=get_org_dropdown, outputs=[org_name])
+    app.load(fn=get_random_repo_name, outputs=[repo_name])
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -20,6 +20,7 @@ from synthetic_dataset_generator.apps.base import (
     validate_push_to_hub,
 )
 from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
+from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts
 from synthetic_dataset_generator.pipelines.embeddings import (
     get_embeddings,
     get_sentence_embedding_dimensions,
@@ -35,6 +36,7 @@ from synthetic_dataset_generator.utils import (
     get_argilla_client,
     get_org_dropdown,
     get_preprocess_labels,
+    get_random_repo_name,
     swap_visibility,
 )

@@ -106,7 +108,7 @@ def generate_dataset(
     )
     updated_system_prompt = f"{system_prompt}. Optional labels: {', '.join(labels)}."
     if multi_label:
-        updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is better than applying too many labels."
+        updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is always better than applying too many labels."
     labeller_generator = get_labeller_generator(
         system_prompt=updated_system_prompt,
         labels=labels,
@@ -118,6 +120,7 @@ def generate_dataset(
     # create text classification data
     n_processed = 0
     textcat_results = []
+    rewritten_system_prompts = get_rewriten_prompts(system_prompt, num_rows)
     while n_processed < num_rows:
         progress(
             2 * 0.5 * n_processed / num_rows,
@@ -128,25 +131,24 @@ def generate_dataset(
         batch_size = min(batch_size, remaining_rows)
         inputs = []
         for _ in range(batch_size):
+            k = 1
             if multi_label:
                 num_labels = len(labels)
                 k = int(
                     random.betavariate(alpha=(num_labels - 1), beta=num_labels)
                     * num_labels
                 )
-            else:
-                k = 1
-
             sampled_labels = random.sample(labels, min(k, len(labels)))
             random.shuffle(sampled_labels)
             inputs.append(
                 {
-                    "task": f"{system_prompt}. The text represents the following categories: {', '.join(sampled_labels)}"
+                    "task": f"{random.choice(rewritten_system_prompts)}. The text represents the following categories: {', '.join(sampled_labels)}"
                 }
             )
         batch = list(textcat_generator.process(inputs=inputs))
         textcat_results.extend(batch[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
     for result in textcat_results:
         result["text"] = result["input_text"]

@@ -164,6 +166,7 @@ def generate_dataset(
         labels_batch = list(labeller_generator.process(inputs=batch))
         labeller_results.extend(labels_batch[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
     progress(
         1,
         total=total_steps,
@@ -250,7 +253,7 @@ def push_dataset_to_hub(
         dataframe.reset_index(drop=True),
         features=features,
     )
-    dataset = combine_datasets(repo_id, dataset)
+    dataset = combine_datasets(repo_id, dataset, oauth_token)
     distiset = Distiset({"default": dataset})
     progress(0.9, desc="Pushing dataset")
     distiset.push_to_hub(
@@ -662,3 +665,4 @@ with gr.Blocks() as app:

     app.load(fn=swap_visibility, outputs=main_ui)
     app.load(fn=get_org_dropdown, outputs=[org_name])
+    app.load(fn=get_random_repo_name, outputs=[repo_name])
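A note on the betavariate draw used above: with alpha=num_labels - 1 and beta=num_labels, the sampled fraction has mean (num_labels - 1) / (2 * num_labels - 1), just under one half, so after scaling most rows receive a small number of labels. A standalone sketch of the distribution this produces (the label count of 5 is illustrative):

    import random
    from collections import Counter

    num_labels = 5  # e.g. five candidate labels
    counts = Counter()
    for _ in range(10_000):
        # Beta(4, 5) has mean 4/9, so the scaled count skews low
        k = int(random.betavariate(alpha=num_labels - 1, beta=num_labels) * num_labels)
        counts[k] += 1

    # most of the mass lands on k = 1..3; the app then takes
    # random.sample(labels, min(k, len(labels)))
    print(sorted(counts.items()))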
src/synthetic_dataset_generator/constants.py
CHANGED
@@ -81,7 +81,7 @@ if OPENAI_BASE_URL:

 if not SFT_AVAILABLE:
     warnings.warn(
-        "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`."
+        "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`. You can also use `HUGGINGFACE_BASE_URL` to with vllm."
     )
     MAGPIE_PRE_QUERY_TEMPLATE = None

src/synthetic_dataset_generator/pipelines/base.py
CHANGED
@@ -1,8 +1,13 @@
+import math
+import random
+
 import gradio as gr
 from distilabel.llms import InferenceEndpointsLLM, OllamaLLM, OpenAILLM
+from distilabel.steps.tasks import TextGeneration

 from synthetic_dataset_generator.constants import (
     API_KEYS,
+    DEFAULT_BATCH_SIZE,
     HUGGINGFACE_BASE_URL,
     MAGPIE_PRE_QUERY_TEMPLATE,
     MODEL,
@@ -21,6 +26,41 @@ def _get_next_api_key():
     return api_key


+def _get_prompt_rewriter():
+    generation_kwargs = {
+        "temperature": 1,
+    }
+    system_prompt = "You are a prompt rewriter. You are given a prompt and you need to rewrite it keeping the same structure but highlighting different aspects of the original without adding anything new."
+    prompt_rewriter = TextGeneration(
+        llm=_get_llm(generation_kwargs=generation_kwargs),
+        system_prompt=system_prompt,
+        use_system_prompt=True,
+    )
+    prompt_rewriter.load()
+    return prompt_rewriter
+
+
+def get_rewriten_prompts(prompt: str, num_rows: int):
+    prompt_rewriter = _get_prompt_rewriter()
+    # create prompt rewrites
+    inputs = [
+        {"instruction": f"Original prompt: {prompt} \nRewritten prompt: "}
+        for i in range(math.floor(num_rows / 100))
+    ]
+    n_processed = 0
+    prompt_rewrites = [prompt]
+    while n_processed < num_rows:
+        batch = list(
+            prompt_rewriter.process(
+                inputs=inputs[n_processed : n_processed + DEFAULT_BATCH_SIZE]
+            )
+        )
+        prompt_rewrites += [entry["generation"] for entry in batch[0]]
+        n_processed += DEFAULT_BATCH_SIZE
+        random.seed(a=random.randint(0, 2**32 - 1))
+    return prompt_rewrites
+
+
 def _get_llm(use_magpie_template=False, **kwargs):
     if OPENAI_BASE_URL:
         llm = OpenAILLM(
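The sizing logic is worth spelling out: get_rewriten_prompts issues roughly one rewrite request per 100 requested rows (math.floor(num_rows / 100)) and always keeps the original prompt as the first candidate, so downstream random.choice can still pick the unmodified prompt. A rough usage sketch with placeholder strings standing in for the LLM's rewrites:

    import random

    # hypothetical stand-in for get_rewriten_prompts(prompt, num_rows=300):
    # math.floor(300 / 100) = 3 rewrite instructions would be sent to the LLM
    prompt = "Classify customer reviews by sentiment"
    prompt_rewrites = [prompt] + [
        f"{prompt} (rewrite {i + 1})" for i in range(300 // 100)  # placeholder rewrites
    ]

    # each generated row then draws its system prompt at random, as in textcat.py
    for _ in range(3):
        print(random.choice(prompt_rewrites))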
src/synthetic_dataset_generator/pipelines/chat.py
CHANGED
@@ -203,15 +203,6 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
     return magpie_generator


-def get_prompt_rewriter():
-    generation_kwargs = {
-        "temperature": 1,
-    }
-    prompt_rewriter = TextGeneration(llm=_get_llm(generation_kwargs=generation_kwargs))
-    prompt_rewriter.load()
-    return prompt_rewriter
-
-
 def get_response_generator(system_prompt, num_turns, temperature, is_sample):
     if num_turns == 1:
         generation_kwargs = {
src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED
@@ -94,7 +94,6 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
         "top_p": 0.95,
     }
     llm = _get_llm(generation_kwargs=generation_kwargs)
-
     textcat_generator = GenerateTextClassificationData(
         llm=llm,
         difficulty=None if difficulty == "mixed" else difficulty,
@@ -111,7 +110,6 @@ def get_labeller_generator(system_prompt, labels, multi_label):
         "max_new_tokens": MAX_NUM_TOKENS,
     }
     llm = _get_llm(generation_kwargs=generation_kwargs)
-
     labeller_generator = TextClassification(
         llm=llm,
         context=system_prompt,
src/synthetic_dataset_generator/utils.py
CHANGED
@@ -1,4 +1,5 @@
 import json
+import uuid
 import warnings
 from typing import List, Optional, Union

@@ -55,6 +56,10 @@ def list_orgs(oauth_token: Union[OAuthToken, None] = None):
     return organizations


+def get_random_repo_name():
+    return f"my-distiset-{str(uuid.uuid4())[:8]}"
+
+
 def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None):
     if oauth_token is not None:
         orgs = list_orgs(oauth_token)
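Each app registers this helper with app.load, so the repo-name field is pre-filled with a fresh suggestion such as my-distiset-1a2b3c4d on every page load, which makes accidentally pushing into a previous dataset less likely. A minimal sketch of the wiring; the textbox is an illustrative stand-in for the apps' repo_name component:

    import uuid

    import gradio as gr


    def get_random_repo_name():
        # the first 8 hex characters of a uuid4 are enough to avoid clashes here
        return f"my-distiset-{str(uuid.uuid4())[:8]}"


    with gr.Blocks() as app:
        repo_name = gr.Textbox(label="Repo name")  # illustrative component
        app.load(fn=get_random_repo_name, outputs=[repo_name])  # pre-fill on load

    app.launch()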
|