Commit 2841b26 · Parent(s): 32d8669
Author: davidberenstein1957

add randomisation of system prompts for generation
README.md CHANGED
@@ -87,7 +87,7 @@ Optionally, you can use different API providers and models.
  - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
  - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
  - `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
- - `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/v1/`.
+ - `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/`.
  - `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. a TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`.
 
  SFT and Chat Data generation is only supported with Hugging Face Inference Endpoints, and you can set the following environment variables to use it with models other than Llama3 and Qwen2.
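
As a usage note, these are plain environment variables read at startup, so any OpenAI-compatible endpoint can be wired up the same way the Ollama example below wires up its own. A minimal sketch, assuming the `launch` entry point imported in `examples/ollama_deployment.py`:

```python
# Minimal sketch of the OpenAI-compatible route; the launch() entry point
# is assumed to be the one used in examples/ollama_deployment.py below.
import os

from synthetic_dataset_generator import launch

os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/"  # any OpenAI compatible API
os.environ["MODEL"] = "gpt-4o"  # model served behind that URL
os.environ["API_KEY"] = "sk-..."  # falls back to HF_TOKEN when unset

launch()
```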
examples/ollama_deployment.py CHANGED
@@ -9,6 +9,9 @@ assert os.getenv("HF_TOKEN") # push the data to huggingface
  os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/"  # ollama base url
  os.environ["MODEL"] = "llama3.1:8b-instruct-q8_0"  # model id
  os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.1-8B-Instruct"  # tokenizer id
- os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # magpie template
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"
+ os.environ["MAX_NUM_ROWS"] = "10000"
+ os.environ["DEFAULT_BATCH_SIZE"] = "5"
+ os.environ["MAX_NUM_TOKENS"] = "1024"
 
  launch()
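
Judging from their use elsewhere in this commit, the three new variables mirror constants in `src/synthetic_dataset_generator/constants.py`: `MAX_NUM_ROWS` caps how many rows a run may generate (enforced via `test_max_num_rows` in the apps), `DEFAULT_BATCH_SIZE` sets how many rows each generator call produces, and `MAX_NUM_TOKENS` bounds `max_new_tokens` per generation.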
src/synthetic_dataset_generator/apps/base.py CHANGED
@@ -77,10 +77,15 @@ def validate_push_to_hub(org_name, repo_name):
      return repo_id
 
 
- def combine_datasets(repo_id: str, dataset: Dataset) -> Dataset:
+ def combine_datasets(
+     repo_id: str, dataset: Dataset, oauth_token: Union[OAuthToken, None]
+ ) -> Dataset:
      try:
          new_dataset = load_dataset(
-             repo_id, split="train", download_mode="force_redownload"
+             repo_id,
+             split="train",
+             download_mode="force_redownload",
+             token=oauth_token.token,
          )
          return concatenate_datasets([dataset, new_dataset])
      except Exception:
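
A hypothetical call site for the new signature, mirroring the `push_dataset_to_hub` handlers changed below (the repo id is made up; `oauth_token` is the token Gradio's OAuth flow hands to the handler):

```python
# Hypothetical usage sketch of the new combine_datasets signature.
from datasets import Dataset

from synthetic_dataset_generator.apps.base import combine_datasets

def push(dataframe, oauth_token):
    dataset = Dataset.from_pandas(dataframe)
    # Append the existing "train" split, downloaded with the user's token,
    # so repeated pushes accumulate rows; on failure the except branch
    # (not shown in this hunk) presumably returns the new rows as-is.
    return combine_datasets("my-org/my-distiset", dataset, oauth_token)
```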
src/synthetic_dataset_generator/apps/chat.py CHANGED
@@ -25,12 +25,12 @@ from synthetic_dataset_generator.constants import (
      MODEL,
      SFT_AVAILABLE,
  )
+ from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts
  from synthetic_dataset_generator.pipelines.chat import (
      DEFAULT_DATASET_DESCRIPTIONS,
      generate_pipeline_code,
      get_magpie_generator,
      get_prompt_generator,
-     get_prompt_rewriter,
      get_response_generator,
  )
  from synthetic_dataset_generator.pipelines.embeddings import (
@@ -40,6 +40,7 @@ from synthetic_dataset_generator.pipelines.embeddings import (
  from synthetic_dataset_generator.utils import (
      get_argilla_client,
      get_org_dropdown,
+     get_random_repo_name,
      swap_visibility,
  )
 
@@ -106,7 +107,6 @@ def generate_dataset(
  ) -> pd.DataFrame:
      num_rows = test_max_num_rows(num_rows)
      progress(0.0, desc="(1/2) Generating instructions")
-     prompt_rewriter = get_prompt_rewriter()
      magpie_generator = get_magpie_generator(
          system_prompt, num_turns, temperature, is_sample
      )
@@ -117,14 +117,7 @@ def generate_dataset(
      batch_size = DEFAULT_BATCH_SIZE
 
      # create prompt rewrites
-     inputs = [
-         {
-             "instruction": f"Rewrite this prompt keeping the same structure but highlighting different aspects of the original without adding anything new. Original prompt: {system_prompt} Rewritten prompt: "
-         }
-         for i in range(int(num_rows / 100))
-     ]
-     batch = list(prompt_rewriter.process(inputs=inputs))
-     prompt_rewrites = [entry["generation"] for entry in batch[0]] + [system_prompt]
+     prompt_rewrites = get_rewriten_prompts(system_prompt, num_rows)
 
      # create instructions
      n_processed = 0
@@ -142,6 +135,7 @@ def generate_dataset(
          batch = list(magpie_generator.process(inputs=inputs))
          magpie_results.extend(batch[0])
          n_processed += batch_size
+         random.seed(a=random.randint(0, 2**32 - 1))
      progress(0.5, desc="(1/2) Generating instructions")
 
      # generate responses
@@ -158,6 +152,7 @@ def generate_dataset(
          responses = list(response_generator.process(inputs=batch))
          response_results.extend(responses[0])
          n_processed += batch_size
+         random.seed(a=random.randint(0, 2**32 - 1))
      for result in response_results:
          result["prompt"] = result["instruction"]
          result["completion"] = result["generation"]
@@ -178,6 +173,7 @@ def generate_dataset(
          responses = list(response_generator.process(inputs=batch))
          response_results.extend(responses[0])
          n_processed += batch_size
+         random.seed(a=random.randint(0, 2**32 - 1))
      for result in response_results:
          result["messages"].append(
              {"role": "assistant", "content": result["generation"]}
@@ -236,7 +232,7 @@ def push_dataset_to_hub(
      dataframe = convert_dataframe_messages(dataframe)
      progress(0.7, desc="Creating dataset")
      dataset = Dataset.from_pandas(dataframe)
-     dataset = combine_datasets(repo_id, dataset)
+     dataset = combine_datasets(repo_id, dataset, oauth_token)
      progress(0.9, desc="Pushing dataset")
      distiset = Distiset({"default": dataset})
      distiset.push_to_hub(
@@ -600,4 +596,5 @@ with gr.Blocks() as app:
          outputs=[dataset_description, system_prompt, num_turns, dataframe],
      )
      app.load(fn=get_org_dropdown, outputs=[org_name])
+     app.load(fn=get_random_repo_name, outputs=[repo_name])
      app.load(fn=swap_visibility, outputs=main_ui)
src/synthetic_dataset_generator/apps/eval.py CHANGED
@@ -41,6 +41,7 @@ from synthetic_dataset_generator.utils import (
      extract_column_names,
      get_argilla_client,
      get_org_dropdown,
+     get_random_repo_name,
      pad_or_truncate_list,
      process_columns,
      swap_visibility,
@@ -359,7 +360,7 @@ def push_dataset_to_hub(
  ):
      repo_id = validate_push_to_hub(org_name, repo_name)
      dataset = Dataset.from_pandas(dataframe)
-     dataset = combine_datasets(repo_id, dataset)
+     dataset = combine_datasets(repo_id, dataset, oauth_token)
      distiset = Distiset({"default": dataset})
      distiset.push_to_hub(
          repo_id=repo_id,
@@ -907,3 +908,4 @@
 
      app.load(fn=swap_visibility, outputs=main_ui)
      app.load(fn=get_org_dropdown, outputs=[org_name])
+     app.load(fn=get_random_repo_name, outputs=[repo_name])
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -20,6 +20,7 @@ from synthetic_dataset_generator.apps.base import (
      validate_push_to_hub,
  )
  from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
+ from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts
  from synthetic_dataset_generator.pipelines.embeddings import (
      get_embeddings,
      get_sentence_embedding_dimensions,
@@ -35,6 +36,7 @@ from synthetic_dataset_generator.utils import (
      get_argilla_client,
      get_org_dropdown,
      get_preprocess_labels,
+     get_random_repo_name,
      swap_visibility,
  )
 
@@ -106,7 +108,7 @@ def generate_dataset(
      )
      updated_system_prompt = f"{system_prompt}. Optional labels: {', '.join(labels)}."
      if multi_label:
-         updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is better than applying too many labels."
+         updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is always better than applying too many labels."
      labeller_generator = get_labeller_generator(
          system_prompt=updated_system_prompt,
          labels=labels,
@@ -118,6 +120,7 @@ def generate_dataset(
      # create text classification data
      n_processed = 0
      textcat_results = []
+     rewritten_system_prompts = get_rewriten_prompts(system_prompt, num_rows)
      while n_processed < num_rows:
          progress(
              2 * 0.5 * n_processed / num_rows,
@@ -128,25 +131,24 @@ def generate_dataset(
          batch_size = min(batch_size, remaining_rows)
          inputs = []
          for _ in range(batch_size):
+             k = 1
              if multi_label:
                  num_labels = len(labels)
                  k = int(
                      random.betavariate(alpha=(num_labels - 1), beta=num_labels)
                      * num_labels
                  )
-             else:
-                 k = 1
-
              sampled_labels = random.sample(labels, min(k, len(labels)))
              random.shuffle(sampled_labels)
              inputs.append(
                  {
-                     "task": f"{system_prompt}. The text represents the following categories: {', '.join(sampled_labels)}"
+                     "task": f"{random.choice(rewritten_system_prompts)}. The text represents the following categories: {', '.join(sampled_labels)}"
                  }
              )
          batch = list(textcat_generator.process(inputs=inputs))
          textcat_results.extend(batch[0])
          n_processed += batch_size
+         random.seed(a=random.randint(0, 2**32 - 1))
      for result in textcat_results:
          result["text"] = result["input_text"]
 
@@ -164,6 +166,7 @@ def generate_dataset(
          labels_batch = list(labeller_generator.process(inputs=batch))
          labeller_results.extend(labels_batch[0])
          n_processed += batch_size
+         random.seed(a=random.randint(0, 2**32 - 1))
      progress(
          1,
          total=total_steps,
@@ -250,7 +253,7 @@ def push_dataset_to_hub(
          dataframe.reset_index(drop=True),
          features=features,
      )
-     dataset = combine_datasets(repo_id, dataset)
+     dataset = combine_datasets(repo_id, dataset, oauth_token)
      distiset = Distiset({"default": dataset})
      progress(0.9, desc="Pushing dataset")
      distiset.push_to_hub(
@@ -662,3 +665,4 @@
 
      app.load(fn=swap_visibility, outputs=main_ui)
      app.load(fn=get_org_dropdown, outputs=[org_name])
+     app.load(fn=get_random_repo_name, outputs=[repo_name])
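
A worked sketch of the label-count sampler kept in the hunk above: `k` is the floor of a Beta(n-1, n) variate scaled by `n`, so with five labels most texts get one to three of them (the label names here are illustrative):

```python
# Illustrative check of k = floor(Beta(n-1, n) * n) with n = 5 labels:
# the mean is 5 * 4/9 ≈ 2.2, so most draws land on k in {1, 2, 3}.
import random
from collections import Counter

labels = ["sports", "politics", "tech", "health", "travel"]  # made-up labels
num_labels = len(labels)
draws = [
    int(random.betavariate(alpha=num_labels - 1, beta=num_labels) * num_labels)
    for _ in range(10_000)
]
print(Counter(draws).most_common())  # k=2 dominates; k=0 and k=4 are rare
```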
src/synthetic_dataset_generator/constants.py CHANGED
@@ -81,7 +81,7 @@ if OPENAI_BASE_URL:
 
  if not SFT_AVAILABLE:
      warnings.warn(
-         "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`."
+         "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`. You can also use `HUGGINGFACE_BASE_URL` with vLLM."
      )
      MAGPIE_PRE_QUERY_TEMPLATE = None
 
src/synthetic_dataset_generator/pipelines/base.py CHANGED
@@ -1,8 +1,13 @@
+ import math
+ import random
+
  import gradio as gr
  from distilabel.llms import InferenceEndpointsLLM, OllamaLLM, OpenAILLM
+ from distilabel.steps.tasks import TextGeneration
 
  from synthetic_dataset_generator.constants import (
      API_KEYS,
+     DEFAULT_BATCH_SIZE,
      HUGGINGFACE_BASE_URL,
      MAGPIE_PRE_QUERY_TEMPLATE,
      MODEL,
@@ -21,6 +26,41 @@ def _get_next_api_key():
      return api_key
 
 
+ def _get_prompt_rewriter():
+     generation_kwargs = {
+         "temperature": 1,
+     }
+     system_prompt = "You are a prompt rewriter. You are given a prompt and you need to rewrite it keeping the same structure but highlighting different aspects of the original without adding anything new."
+     prompt_rewriter = TextGeneration(
+         llm=_get_llm(generation_kwargs=generation_kwargs),
+         system_prompt=system_prompt,
+         use_system_prompt=True,
+     )
+     prompt_rewriter.load()
+     return prompt_rewriter
+
+
+ def get_rewriten_prompts(prompt: str, num_rows: int):
+     prompt_rewriter = _get_prompt_rewriter()
+     # create prompt rewrites
+     inputs = [
+         {"instruction": f"Original prompt: {prompt} \nRewritten prompt: "}
+         for i in range(math.floor(num_rows / 100))
+     ]
+     n_processed = 0
+     prompt_rewrites = [prompt]
+     while n_processed < num_rows:
+         batch = list(
+             prompt_rewriter.process(
+                 inputs=inputs[n_processed : n_processed + DEFAULT_BATCH_SIZE]
+             )
+         )
+         prompt_rewrites += [entry["generation"] for entry in batch[0]]
+         n_processed += DEFAULT_BATCH_SIZE
+         random.seed(a=random.randint(0, 2**32 - 1))
+     return prompt_rewrites
+
+
  def _get_llm(use_magpie_template=False, **kwargs):
      if OPENAI_BASE_URL:
          llm = OpenAILLM(
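
For orientation, a rough sketch of how the apps consume the new helper (the prompt and row count are made up; the helper always returns the original prompt as the first entry, followed by `floor(num_rows / 100)` rewrites):

```python
# Illustrative use of get_rewriten_prompts, mirroring the textcat app above.
import random

from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts

system_prompt = "Classify customer support tickets by topic"  # made-up prompt
variants = get_rewriten_prompts(system_prompt, num_rows=300)
# variants[0] is the original; floor(300 / 100) = 3 rewrites follow it.
task = f"{random.choice(variants)}. The text represents the following categories: ..."
```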
src/synthetic_dataset_generator/pipelines/chat.py CHANGED
@@ -203,15 +203,6 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
      return magpie_generator
 
 
- def get_prompt_rewriter():
-     generation_kwargs = {
-         "temperature": 1,
-     }
-     prompt_rewriter = TextGeneration(llm=_get_llm(generation_kwargs=generation_kwargs))
-     prompt_rewriter.load()
-     return prompt_rewriter
-
-
  def get_response_generator(system_prompt, num_turns, temperature, is_sample):
      if num_turns == 1:
          generation_kwargs = {
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -94,7 +94,6 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
          "top_p": 0.95,
      }
      llm = _get_llm(generation_kwargs=generation_kwargs)
-
      textcat_generator = GenerateTextClassificationData(
          llm=llm,
          difficulty=None if difficulty == "mixed" else difficulty,
@@ -111,7 +110,6 @@ def get_labeller_generator(system_prompt, labels, multi_label):
          "max_new_tokens": MAX_NUM_TOKENS,
      }
      llm = _get_llm(generation_kwargs=generation_kwargs)
-
      labeller_generator = TextClassification(
          llm=llm,
          context=system_prompt,
src/synthetic_dataset_generator/utils.py CHANGED
@@ -1,4 +1,5 @@
  import json
+ import uuid
  import warnings
  from typing import List, Optional, Union
 
@@ -55,6 +56,10 @@ def list_orgs(oauth_token: Union[OAuthToken, None] = None):
      return organizations
 
 
+ def get_random_repo_name():
+     return f"my-distiset-{str(uuid.uuid4())[:8]}"
+
+
  def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None):
      if oauth_token is not None:
          orgs = list_orgs(oauth_token)
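
The generated default names take the first eight characters of a hyphenated UUID4 string, which are hex digits; for example (the suffix below is illustrative):

```python
# What get_random_repo_name() produces; the suffix is random on each call.
import uuid

print(f"my-distiset-{str(uuid.uuid4())[:8]}")  # e.g. "my-distiset-9f1c2ab4"
```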