lhoestq HF staff commited on
Commit
06ee5d2
·
1 Parent(s): bd6dcb9

finalize save dataset

Browse files
Files changed (1) hide show
  1. app.py +64 -25
app.py CHANGED
@@ -11,6 +11,7 @@ import ijson
11
  import pandas as pd
12
  import requests
13
  from datasets import Dataset, Features, Value, Sequence
 
14
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
15
  from huggingface_hub import InferenceClient
16
 
@@ -18,16 +19,19 @@ from utils import StringIteratorIO
18
 
19
 
20
  model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
21
- client = InferenceClient(model_id, token=os.environ.get("HF_TOKEN_INFERENCE_API"))
22
 
23
  save_dataset_hf_token = os.environ.get("SAVE_DATASET_HF_TOKEN")
24
  session = requests.Session()
25
  empty_dataframe = pd.DataFrame({"1": [], "2": [], "3": []})
26
 
27
- NAMESPACE = "lhoestq"
28
 
29
  NUM_ROWS_PREVIEW = 3
30
- MAX_NUM_ROWS_TO_REWRITE = 10
 
 
 
31
  REWRITE_DATASET_PREVIEW = (
32
  "A Machine Learning practitioner is looking for a dataset similar to '{dataset}' but slightly different. "
33
  "They want you to rewrite the dataset and apply this instruction, which can be about transforming, translating or filtering the rows: {prompt}."
@@ -77,7 +81,7 @@ with gr.Blocks(css=css) as demo:
77
  subset_dropdown = gr.Dropdown(info="Subset", show_label=False, visible=False)
78
  split_dropdown = gr.Dropdown(info="Split", show_label=False, visible=False)
79
 
80
- gr.Markdown("### Input")
81
  pretty_input_preview = gr.DataFrame(interactive=False)
82
 
83
  gr.Markdown("### ReWrite")
@@ -86,10 +90,12 @@ with gr.Blocks(css=css) as demo:
86
  with gr.Accordion("(Advanced) Edit columns", open=False):
87
  output_format_dataframe = gr.DataFrame(col_count=(2, "fixed"), headers=["column", "type"])
88
  rewrite_preview_button = gr.Button("Preview Results", variant="primary")
89
- pretty_output_preview = gr.DataFrame(interactive=False)
90
  rewrite_full_dataset_button = gr.Button("ReWrite Full Dataset", interactive=False)
 
91
  full_dataset_generation_label = gr.Label(visible=False, show_label=False)
 
92
  full_dataset_generation_success_markdown = gr.Markdown("")
 
93
  with gr.Column(scale=4, min_width="200px"):
94
  with gr.Accordion("Settings", open=False, elem_classes="settings"):
95
  gr.Markdown("Save datasets to your account")
@@ -156,9 +162,9 @@ with gr.Blocks(css=css) as demo:
156
  prompt=prompt,
157
  )}]
158
  response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format}}, "required": ["data"]}}
159
- print("streaming preview")
160
  yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4)
161
- print("done")
162
 
163
 
164
  def stream_rewrite_dataset_row_by_row(dataset: str, rows: list[dict[str, str]], prompt: str, format: str, input_preview_rows: list[dict[str, str]], output_preview_rows: list[dict[str, str]]) -> Iterator[dict[str, str]]:
@@ -171,18 +177,17 @@ with gr.Blocks(css=css) as demo:
171
  output_preview_rows=json.dumps({"data": output_preview_rows}),
172
  )}]
173
  response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format}}, "required": ["data"]}}
174
- print("streaming results")
175
  yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4)
176
- print("done")
177
 
178
 
179
  def find_new_name(dataset: str, prompt: str) -> str:
180
  messages = [{"role": "user", "content": FIND_NEW_NAME.format(prompt=prompt)}]
181
  out = "".join(stream_reponse(messages))
182
  if "should be" in out:
183
- return dataset.split("/")[-1] + out.split("should be", 1)[1].replace(" ", "-").replace(".", "").replace(",", "")
184
  else:
185
- return dataset.split("/")[-1] + prompt.replace(" ", "-")
 
186
 
187
  def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None:
188
  for i, result in enumerate(func(**kwargs)):
@@ -293,7 +298,7 @@ with gr.Blocks(css=css) as demo:
293
  return {rewrite_full_dataset_button: gr.Button(interactive=False)}
294
 
295
 
296
- @rewrite_preview_button.click(inputs=[dataset_search, pretty_input_preview, input_prompt, output_format_dataframe], outputs=[pretty_output_preview, rewrite_full_dataset_button, full_dataset_generation_label])
297
  def rewrite_preview(dataset: str, pretty_input_preview_df: pd.DataFrame, prompt: str, output_format_df: pd.DataFrame) -> Iterator[pd.DataFrame]:
298
  rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_input_preview_df.to_dict(orient="records")]
299
  format = output_format_df.to_dict(orient="records")
@@ -301,34 +306,68 @@ with gr.Blocks(css=css) as demo:
301
  output_rows = []
302
  print(f"ReWriting {dataset} preview with instruction '{prompt}'")
303
  yield {rewrite_full_dataset_button: gr.Button(interactive=False), full_dataset_generation_label: gr.Label(visible=False)}
 
 
 
 
304
  for row in stream_rewrite_dataset_preview_row_by_row(dataset=dataset, rows=rows, prompt=prompt, format=format):
305
  output_rows.append({k: json.dumps(row[k], ensure_ascii=False) for k in output_format_df["column"]})
306
  yield {pretty_output_preview: gr.DataFrame(pd.DataFrame(output_rows))}
307
  yield {rewrite_full_dataset_button: gr.Button(interactive=True)}
308
 
309
 
310
- @rewrite_full_dataset_button.click(inputs=[dataset_search, subset_dropdown, split_dropdown, pretty_input_preview, pretty_output_preview, input_prompt, output_format_dataframe, dataset_info_json, select_namespace_dropdown], outputs=[full_dataset_generation_label, full_dataset_generation_success_markdown])
311
  def rewrite_full_dataset(dataset: str, subset: str, split: str, pretty_input_preview_df: pd.DataFrame, pretty_output_preview_df: pd.DataFrame, prompt: str, output_format_df: pd.DataFrame, dataset_info: dict[str, Any], namespace: str, oauth_token: Optional[gr.OAuthToken]) -> Iterator[pd.DataFrame]:
312
  input_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_input_preview_df.to_dict(orient="records")]
313
  output_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_output_preview_df.to_dict(orient="records")]
314
  format = output_format_df.to_dict(orient="records")
315
  format = {"properties": {x["column"]: json.loads(x["type"]) for x in format}, "required": [x["column"] for x in format]}
316
- output_rows = []
317
  num_examples = dataset_info["splits"][split]["num_examples"]
318
  total = min(num_examples, MAX_NUM_ROWS_TO_REWRITE)
319
  print(f"ReWriting {dataset} (full dataset) with instruction '{prompt}'")
320
  yield {full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": 0.}, visible=True)}
321
- for rows in batched(islice(stream_rows(dataset=dataset, subset=subset, split=split), total), n=10):
322
- for row in stream_rewrite_dataset_row_by_row(dataset=dataset, rows=rows, prompt=prompt, format=format, input_preview_rows=input_preview_rows, output_preview_rows=output_preview_rows):
323
- print(row)
324
- output_rows.append({k: json.dumps(row[k], ensure_ascii=False) for k in output_format_df["column"]})
325
- yield {full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": len(output_rows) / total})}
326
- # repo_id = namespace + "/" + find_new_name(dataset, prompt)
327
- # yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"⚙️ Saving to {repo_id}": 0.})}
328
- # token = oauth_token.token if oauth_token else save_dataset_hf_token
329
- # Dataset.from_list(output_rows).push_to_hub(repo_id, config_name=subset, split=split, token=token)
330
- # yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"✅ Saving to {repo_id}": 1.})}
331
- # yield {full_dataset_generation_success_markdown: f"# Open the ReWriten dataset in a new tab: [{repo_id}](https://huggingface.co/datasets/{repo_id})"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
 
334
  demo.launch()
 
11
  import pandas as pd
12
  import requests
13
  from datasets import Dataset, Features, Value, Sequence
14
+ from datasets.fingerprint import Hasher
15
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
16
  from huggingface_hub import InferenceClient
17
 
 
19
 
20
 
21
  model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
22
+ client = InferenceClient(model_id, token=os.environ.get("INFERENCE_API_HF_TOKEN"))
23
 
24
  save_dataset_hf_token = os.environ.get("SAVE_DATASET_HF_TOKEN")
25
  session = requests.Session()
26
  empty_dataframe = pd.DataFrame({"1": [], "2": [], "3": []})
27
 
28
+ NAMESPACE = "dataset-rewriter"
29
 
30
  NUM_ROWS_PREVIEW = 3
31
+ MAX_NUM_ROWS_TO_REWRITE = 1000
32
+ NUM_PARALLEL_CALLS = 10
33
+ NUM_ROWS_PER_CALL = 10
34
+ MAX_PROGRESS_UPDATES_PER_SECOND = 4
35
  REWRITE_DATASET_PREVIEW = (
36
  "A Machine Learning practitioner is looking for a dataset similar to '{dataset}' but slightly different. "
37
  "They want you to rewrite the dataset and apply this instruction, which can be about transforming, translating or filtering the rows: {prompt}."
 
81
  subset_dropdown = gr.Dropdown(info="Subset", show_label=False, visible=False)
82
  split_dropdown = gr.Dropdown(info="Split", show_label=False, visible=False)
83
 
84
+ gr.Markdown("### Sample")
85
  pretty_input_preview = gr.DataFrame(interactive=False)
86
 
87
  gr.Markdown("### ReWrite")
 
90
  with gr.Accordion("(Advanced) Edit columns", open=False):
91
  output_format_dataframe = gr.DataFrame(col_count=(2, "fixed"), headers=["column", "type"])
92
  rewrite_preview_button = gr.Button("Preview Results", variant="primary")
 
93
  rewrite_full_dataset_button = gr.Button("ReWrite Full Dataset", interactive=False)
94
+ gr.Markdown("#### Output")
95
  full_dataset_generation_label = gr.Label(visible=False, show_label=False)
96
+ pretty_output_preview = gr.DataFrame(interactive=False)
97
  full_dataset_generation_success_markdown = gr.Markdown("")
98
+ pretty_full_dataset_generation_output = gr.DataFrame(interactive=False, visible=False)
99
  with gr.Column(scale=4, min_width="200px"):
100
  with gr.Accordion("Settings", open=False, elem_classes="settings"):
101
  gr.Markdown("Save datasets to your account")
 
162
  prompt=prompt,
163
  )}]
164
  response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format}}, "required": ["data"]}}
165
+ print(f"Streaming preview of {dataset} with instruction '{prompt}'")
166
  yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4)
167
+ print(f"Done streaming preview of {dataset} with instruction '{prompt}'")
168
 
169
 
170
  def stream_rewrite_dataset_row_by_row(dataset: str, rows: list[dict[str, str]], prompt: str, format: str, input_preview_rows: list[dict[str, str]], output_preview_rows: list[dict[str, str]]) -> Iterator[dict[str, str]]:
 
177
  output_preview_rows=json.dumps({"data": output_preview_rows}),
178
  )}]
179
  response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format}}, "required": ["data"]}}
 
180
  yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4)
 
181
 
182
 
183
  def find_new_name(dataset: str, prompt: str) -> str:
184
  messages = [{"role": "user", "content": FIND_NEW_NAME.format(prompt=prompt)}]
185
  out = "".join(stream_reponse(messages))
186
  if "should be" in out:
187
+ out = dataset.split("/")[-1] + out.split("should be", 1)[1].replace(" ", "-").replace(".", "").replace(",", "")
188
  else:
189
+ out = dataset.split("/")[-1] + prompt.replace(" ", "-")
190
+ return out[:90] + "-" + Hasher.hash(prompt)[:4]
191
 
192
  def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None:
193
  for i, result in enumerate(func(**kwargs)):
 
298
  return {rewrite_full_dataset_button: gr.Button(interactive=False)}
299
 
300
 
301
+ @rewrite_preview_button.click(inputs=[dataset_search, pretty_input_preview, input_prompt, output_format_dataframe], outputs=[pretty_output_preview, rewrite_full_dataset_button, full_dataset_generation_label, pretty_full_dataset_generation_output])
302
  def rewrite_preview(dataset: str, pretty_input_preview_df: pd.DataFrame, prompt: str, output_format_df: pd.DataFrame) -> Iterator[pd.DataFrame]:
303
  rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_input_preview_df.to_dict(orient="records")]
304
  format = output_format_df.to_dict(orient="records")
 
306
  output_rows = []
307
  print(f"ReWriting {dataset} preview with instruction '{prompt}'")
308
  yield {rewrite_full_dataset_button: gr.Button(interactive=False), full_dataset_generation_label: gr.Label(visible=False)}
309
+ yield {
310
+ pretty_output_preview: gr.DataFrame(visible=True),
311
+ pretty_full_dataset_generation_output: gr.DataFrame(visible=False),
312
+ }
313
  for row in stream_rewrite_dataset_preview_row_by_row(dataset=dataset, rows=rows, prompt=prompt, format=format):
314
  output_rows.append({k: json.dumps(row[k], ensure_ascii=False) for k in output_format_df["column"]})
315
  yield {pretty_output_preview: gr.DataFrame(pd.DataFrame(output_rows))}
316
  yield {rewrite_full_dataset_button: gr.Button(interactive=True)}
317
 
318
 
319
+ @rewrite_full_dataset_button.click(inputs=[dataset_search, subset_dropdown, split_dropdown, pretty_input_preview, pretty_output_preview, input_prompt, output_format_dataframe, dataset_info_json, select_namespace_dropdown], outputs=[full_dataset_generation_label, full_dataset_generation_success_markdown, pretty_output_preview, pretty_full_dataset_generation_output])
320
  def rewrite_full_dataset(dataset: str, subset: str, split: str, pretty_input_preview_df: pd.DataFrame, pretty_output_preview_df: pd.DataFrame, prompt: str, output_format_df: pd.DataFrame, dataset_info: dict[str, Any], namespace: str, oauth_token: Optional[gr.OAuthToken]) -> Iterator[pd.DataFrame]:
321
  input_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_input_preview_df.to_dict(orient="records")]
322
  output_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_output_preview_df.to_dict(orient="records")]
323
  format = output_format_df.to_dict(orient="records")
324
  format = {"properties": {x["column"]: json.loads(x["type"]) for x in format}, "required": [x["column"] for x in format]}
 
325
  num_examples = dataset_info["splits"][split]["num_examples"]
326
  total = min(num_examples, MAX_NUM_ROWS_TO_REWRITE)
327
  print(f"ReWriting {dataset} (full dataset) with instruction '{prompt}'")
328
  yield {full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": 0.}, visible=True)}
329
+ yield {pretty_full_dataset_generation_output: empty_dataframe}
330
+ yield {
331
+ pretty_output_preview: gr.DataFrame(visible=False),
332
+ pretty_full_dataset_generation_output: gr.DataFrame(visible=True),
333
+ }
334
+
335
+ num_parallel_calls = max(1, min(total // NUM_ROWS_PER_CALL, NUM_PARALLEL_CALLS))
336
+ parallel_input_rows = list(batched(islice(stream_rows(dataset=dataset, subset=subset, split=split), total), n=total // num_parallel_calls))
337
+ parallel_output_rows = [[] for _ in range(num_parallel_calls)]
338
+
339
+ def run(i):
340
+ for batch_rows in batched(parallel_input_rows[i], n=NUM_ROWS_PER_CALL):
341
+ for row in stream_rewrite_dataset_row_by_row(dataset=dataset, rows=batch_rows, prompt=prompt, format=format, input_preview_rows=input_preview_rows, output_preview_rows=output_preview_rows):
342
+ parallel_output_rows[i].append({k: json.dumps(row[k], ensure_ascii=False) for k in output_format_df["column"]})
343
+ yield 1
344
+
345
+ current = 0
346
+ _last_time = time.time()
347
+ for step in iflatmap_unordered(run, kwargs_iterable=[{"i": i} for i in range(num_parallel_calls)]):
348
+ current += step
349
+ if _last_time + 1 / MAX_PROGRESS_UPDATES_PER_SECOND < time.time():
350
+ _last_time = time.time()
351
+ yield {
352
+ full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": current / total}),
353
+ pretty_full_dataset_generation_output: gr.DataFrame(pd.DataFrame([row for rows in parallel_output_rows for row in rows]))
354
+ }
355
+ yield {
356
+ full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": current / total}),
357
+ pretty_full_dataset_generation_output: gr.DataFrame(pd.DataFrame([row for rows in parallel_output_rows for row in rows]))
358
+ }
359
+ print(f"Done ReWriting {dataset} (full dataset) with instruction '{prompt}'")
360
+
361
+ output_rows = [{k: json.loads(row[k]) for k in output_format_df["column"]} for rows in parallel_output_rows for row in rows]
362
+ repo_id = namespace + "/" + find_new_name(dataset, prompt)
363
+ yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"⚙️ Saving to {repo_id}": 0.})}
364
+ token = oauth_token.token if oauth_token else save_dataset_hf_token
365
+ print(f"Saving {repo_id}")
366
+ ds = Dataset.from_list(output_rows)
367
+ ds.push_to_hub(repo_id, config_name=subset, split=split, token=token)
368
+ yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"✅ Saving to {repo_id}": 1.})}
369
+ yield {full_dataset_generation_success_markdown: f"# Open the ReWriten dataset in a new tab: [{repo_id}](https://huggingface.co/datasets/{repo_id})"}
370
+ print(f"Saved {repo_id}")
371
 
372
 
373
  demo.launch()