sebaxakerhtc committed on
Commit
8291c8c
·
unverified ·
1 Parent(s): a68cd13

Added local saving to CSV and JSON (#38)

Browse files

* Local save

- Added save local function to chat tab (CSV, JSON)
- Rebuild UI with new feature
- CSS edit for gr.File (perfectionism)

* Local save

* Mistake

* Update chat.py

* Local save RAG and Textcat

* Rebuild UI

* Show save_local only if save_local_dir is provided

src/synthetic_dataset_generator/app.py CHANGED
@@ -12,12 +12,13 @@ css = """
12
  .main_ui_logged_out{opacity: 0.3; pointer-events: none}
13
  button[role="tab"][aria-selected="true"] { border: 0; background: var(--button-primary-background-fill); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
14
  button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill); background: var(var(--button-primary-background-fill-hover))}
15
- .tabitem { border: 0; padding-inline: 0}
16
  .gallery-item {background: var(--background-fill-secondary); text-align: left}
17
- .table-wrap .tbody td { vertical-align: top }
18
- #system_prompt_examples { color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
19
  .container {padding-inline: 0 !important}
20
- #sign_in_button { flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto; }
 
21
  """
22
 
23
  image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
 
12
  .main_ui_logged_out{opacity: 0.3; pointer-events: none}
13
  button[role="tab"][aria-selected="true"] { border: 0; background: var(--button-primary-background-fill); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
14
  button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill); background: var(var(--button-primary-background-fill-hover))}
15
+ .tabitem {border: 0; padding-inline: 0}
16
  .gallery-item {background: var(--background-fill-secondary); text-align: left}
17
+ .table-wrap .tbody td {vertical-align: top}
18
+ #system_prompt_examples {color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
19
  .container {padding-inline: 0 !important}
20
+ #sign_in_button {flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto;}
21
+ .datasets {height: 70px;}
22
  """
23
 
24
  image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
src/synthetic_dataset_generator/apps/base.py CHANGED
@@ -12,9 +12,13 @@ from huggingface_hub import HfApi, upload_file, repo_exists
12
  from unstructured.chunking.title import chunk_by_title
13
  from unstructured.partition.auto import partition
14
 
15
- from synthetic_dataset_generator.constants import MAX_NUM_ROWS
16
  from synthetic_dataset_generator.utils import get_argilla_client
17
 
 
 
 
 
18
 
19
  def validate_argilla_user_workspace_dataset(
20
  dataset_name: str,
 
12
  from unstructured.chunking.title import chunk_by_title
13
  from unstructured.partition.auto import partition
14
 
15
+ from synthetic_dataset_generator.constants import MAX_NUM_ROWS, SAVE_LOCAL_DIR
16
  from synthetic_dataset_generator.utils import get_argilla_client
17
 
18
+ if SAVE_LOCAL_DIR is not None:
19
+ import os
20
+ os.makedirs(SAVE_LOCAL_DIR, exist_ok=True)
21
+
22
 
23
  def validate_argilla_user_workspace_dataset(
24
  dataset_name: str,
src/synthetic_dataset_generator/apps/chat.py CHANGED
@@ -2,6 +2,7 @@ import ast
2
  import json
3
  import random
4
  import uuid
 
5
  from typing import Dict, List, Union
6
 
7
  import argilla as rg
@@ -30,6 +31,7 @@ from synthetic_dataset_generator.constants import (
30
  MODEL,
31
  MODEL_COMPLETION,
32
  SFT_AVAILABLE,
 
33
  )
34
  from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
35
  from synthetic_dataset_generator.pipelines.chat import (
@@ -264,7 +266,6 @@ def generate_dataset_from_prompt(
264
  progress(1.0, desc="Dataset generation completed")
265
  return dataframe
266
 
267
-
268
  def generate_dataset_from_seed(
269
  dataframe: pd.DataFrame,
270
  document_column: str,
@@ -506,7 +507,7 @@ def push_dataset(
506
  num_turns=num_turns,
507
  num_rows=num_rows,
508
  temperature=temperature,
509
- temperature_completion=temperature_completion
510
  )
511
  push_dataset_to_hub(
512
  dataframe=dataframe,
@@ -637,6 +638,45 @@ def push_dataset(
637
  return ""
638
 
639
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
640
  def show_system_prompt_visibility():
641
  return {system_prompt: gr.Textbox(visible=True)}
642
 
@@ -670,6 +710,13 @@ def hide_pipeline_code_visibility():
670
  def show_temperature_completion():
671
  if MODEL != MODEL_COMPLETION:
672
  return {temperature_completion: gr.Slider(value=0.9, visible=True)}
 
 
 
 
 
 
 
673
 
674
 
675
  ######################
@@ -852,6 +899,11 @@ with gr.Blocks() as app:
852
  btn_push_to_hub = gr.Button(
853
  "Push to Hub", variant="primary", scale=2
854
  )
 
 
 
 
 
855
  with gr.Column(scale=3):
856
  success_message = gr.Markdown(
857
  visible=True,
@@ -998,6 +1050,23 @@ with gr.Blocks() as app:
998
  inputs=[],
999
  outputs=[pipeline_code_ui],
1000
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1001
 
1002
  clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
1003
  clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
@@ -1011,3 +1080,5 @@ with gr.Blocks() as app:
1011
  app.load(fn=get_org_dropdown, outputs=[org_name])
1012
  app.load(fn=get_random_repo_name, outputs=[repo_name])
1013
  app.load(fn=show_temperature_completion, outputs=[temperature_completion])
 
 
 
2
  import json
3
  import random
4
  import uuid
5
+ import os
6
  from typing import Dict, List, Union
7
 
8
  import argilla as rg
 
31
  MODEL,
32
  MODEL_COMPLETION,
33
  SFT_AVAILABLE,
34
+ SAVE_LOCAL_DIR,
35
  )
36
  from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
37
  from synthetic_dataset_generator.pipelines.chat import (
 
266
  progress(1.0, desc="Dataset generation completed")
267
  return dataframe
268
 
 
269
  def generate_dataset_from_seed(
270
  dataframe: pd.DataFrame,
271
  document_column: str,
 
507
  num_turns=num_turns,
508
  num_rows=num_rows,
509
  temperature=temperature,
510
+ temperature_completion=temperature_completion,
511
  )
512
  push_dataset_to_hub(
513
  dataframe=dataframe,
 
638
  return ""
639
 
640
 
641
def save_local(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int,
    num_rows: int,
    temperature: float,
    repo_name: str,
    temperature_completion: Union[float, None] = None,
) -> tuple[str, str]:
    """Generate a chat dataset and write it to ``SAVE_LOCAL_DIR`` as CSV and JSON.

    Args:
        repo_id: Forwarded to ``load_dataset_file`` when the input is not a prompt.
        file_paths: Forwarded to ``load_dataset_file`` when the input is not a prompt.
        input_type: ``"prompt-input"`` starts from ``_get_dataframe()``; any
            other value loads the seed data through ``load_dataset_file``.
        system_prompt: Forwarded to ``generate_dataset``.
        document_column: Forwarded to ``generate_dataset``.
        num_turns: Forwarded to ``generate_dataset``.
        num_rows: Forwarded to both ``load_dataset_file`` and ``generate_dataset``.
        temperature: Forwarded to ``generate_dataset``.
        repo_name: Base name (without extension) of the two output files.
        temperature_completion: Forwarded to ``generate_dataset``.

    Returns:
        ``(output_csv, output_json)``: the paths of the two files written
        inside ``SAVE_LOCAL_DIR``.
    """
    if input_type == "prompt-input":
        dataframe = _get_dataframe()
    else:
        # Only the dataframe is needed; the second returned value is discarded.
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
        )
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    local_dataset = Dataset.from_pandas(dataframe)
    # NOTE(review): assumes SAVE_LOCAL_DIR is set; os.path.join raises
    # TypeError when it is None (the UI hides the button in that case).
    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
    local_dataset.to_csv(output_csv, index=False)
    local_dataset.to_json(output_json, index=False)
    return output_csv, output_json
678
+
679
+
680
  def show_system_prompt_visibility():
681
  return {system_prompt: gr.Textbox(visible=True)}
682
 
 
710
  def show_temperature_completion():
711
  if MODEL != MODEL_COMPLETION:
712
  return {temperature_completion: gr.Slider(value=0.9, visible=True)}
713
+
714
def show_save_local():
    """Return component updates that make the local-save button and both file outputs visible."""
    revealed = {}
    revealed[btn_save_local] = gr.Button(visible=True)
    revealed[csv_file] = gr.File(visible=True)
    revealed[json_file] = gr.File(visible=True)
    return revealed
720
 
721
 
722
  ######################
 
899
  btn_push_to_hub = gr.Button(
900
  "Push to Hub", variant="primary", scale=2
901
  )
902
+ btn_save_local = gr.Button(
903
+ "Save locally", variant="primary", scale=2, visible=False
904
+ )
905
+ csv_file = gr.File(label="CSV", elem_classes="datasets", visible=False)
906
+ json_file = gr.File(label="JSON", elem_classes="datasets", visible=False)
907
  with gr.Column(scale=3):
908
  success_message = gr.Markdown(
909
  visible=True,
 
1050
  inputs=[],
1051
  outputs=[pipeline_code_ui],
1052
  )
1053
+
1054
+ btn_save_local.click(
1055
+ save_local,
1056
+ inputs=[
1057
+ search_in,
1058
+ file_in,
1059
+ input_type,
1060
+ system_prompt,
1061
+ document_column,
1062
+ num_turns,
1063
+ num_rows,
1064
+ temperature,
1065
+ repo_name,
1066
+ temperature_completion,
1067
+ ],
1068
+ outputs=[csv_file, json_file]
1069
+ )
1070
 
1071
  clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
1072
  clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
 
1080
  app.load(fn=get_org_dropdown, outputs=[org_name])
1081
  app.load(fn=get_random_repo_name, outputs=[repo_name])
1082
  app.load(fn=show_temperature_completion, outputs=[temperature_completion])
1083
+ if SAVE_LOCAL_DIR is not None:
1084
+ app.load(fn=show_save_local, outputs=[btn_save_local, csv_file, json_file])
src/synthetic_dataset_generator/apps/rag.py CHANGED
@@ -24,7 +24,7 @@ from synthetic_dataset_generator.apps.base import (
24
  validate_argilla_user_workspace_dataset,
25
  validate_push_to_hub,
26
  )
27
- from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, MODEL, MODEL_COMPLETION
28
  from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
29
  from synthetic_dataset_generator.pipelines.embeddings import (
30
  get_embeddings,
@@ -486,6 +486,49 @@ def push_dataset(
486
  return ""
487
 
488
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
  def show_system_prompt_visibility():
490
  return {system_prompt: gr.Textbox(visible=True)}
491
 
@@ -521,6 +564,14 @@ def show_temperature_completion():
521
  return {temperature_completion: gr.Slider(value=0.9, visible=True)}
522
 
523
 
 
 
 
 
 
 
 
 
524
  ######################
525
  # Gradio UI
526
  ######################
@@ -674,7 +725,14 @@ with gr.Blocks() as app:
674
  interactive=True,
675
  scale=1,
676
  )
677
- btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
 
 
 
 
 
 
 
678
  with gr.Column(scale=3):
679
  success_message = gr.Markdown(
680
  visible=True,
@@ -822,6 +880,23 @@ with gr.Blocks() as app:
822
  outputs=[pipeline_code_ui],
823
  )
824
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
825
  clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
826
  clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
827
  clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
@@ -835,3 +910,5 @@ with gr.Blocks() as app:
835
  app.load(fn=get_org_dropdown, outputs=[org_name])
836
  app.load(fn=get_random_repo_name, outputs=[repo_name])
837
  app.load(fn=show_temperature_completion, outputs=[temperature_completion])
 
 
 
24
  validate_argilla_user_workspace_dataset,
25
  validate_push_to_hub,
26
  )
27
+ from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, MODEL, MODEL_COMPLETION, SAVE_LOCAL_DIR
28
  from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
29
  from synthetic_dataset_generator.pipelines.embeddings import (
30
  get_embeddings,
 
486
  return ""
487
 
488
 
489
def save_local(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: int,
    temperature: float,
    repo_name: str,
    temperature_completion: float,
) -> tuple[str, str]:
    """Generate a RAG dataset and write it to ``SAVE_LOCAL_DIR`` as CSV and JSON.

    Args:
        repo_id: Forwarded to ``load_dataset_file`` when the input is not a prompt.
        file_paths: Forwarded to ``load_dataset_file`` when the input is not a prompt.
        input_type: ``"prompt-input"`` starts from an empty frame with the
            columns ``context``, ``question`` and ``response``; any other
            value loads the seed data through ``load_dataset_file``.
        system_prompt: Forwarded to ``generate_dataset``.
        document_column: Forwarded to ``generate_dataset``.
        retrieval_reranking: Selected options; the presence of ``"Retrieval"``
            and ``"Reranking"`` sets the two matching flags.
        num_rows: Forwarded to both ``load_dataset_file`` and ``generate_dataset``.
        temperature: Forwarded to ``generate_dataset``.
        repo_name: Base name (without extension) of the two output files.
        temperature_completion: Forwarded to ``generate_dataset``.

    Returns:
        ``(output_csv, output_json)``: the paths of the two files written
        inside ``SAVE_LOCAL_DIR``.
    """
    retrieval = "Retrieval" in retrieval_reranking
    reranking = "Reranking" in retrieval_reranking

    if input_type == "prompt-input":
        dataframe = pd.DataFrame(columns=["context", "question", "response"])
    else:
        # Only the dataframe is needed; the second returned value is discarded.
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
        )
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        retrieval=retrieval,
        reranking=reranking,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    local_dataset = Dataset.from_pandas(dataframe)
    # NOTE(review): assumes SAVE_LOCAL_DIR is set; os.path.join raises
    # TypeError when it is None (the UI hides the button in that case).
    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
    local_dataset.to_csv(output_csv, index=False)
    local_dataset.to_json(output_json, index=False)
    return output_csv, output_json
530
+
531
+
532
  def show_system_prompt_visibility():
533
  return {system_prompt: gr.Textbox(visible=True)}
534
 
 
564
  return {temperature_completion: gr.Slider(value=0.9, visible=True)}
565
 
566
 
567
def show_save_local():
    """Return component updates that make the local-save button and both file outputs visible."""
    revealed = {}
    revealed[btn_save_local] = gr.Button(visible=True)
    revealed[csv_file] = gr.File(visible=True)
    revealed[json_file] = gr.File(visible=True)
    return revealed
573
+
574
+
575
  ######################
576
  # Gradio UI
577
  ######################
 
725
  interactive=True,
726
  scale=1,
727
  )
728
+ btn_push_to_hub = gr.Button(
729
+ "Push to Hub", variant="primary", scale=2
730
+ )
731
+ btn_save_local = gr.Button(
732
+ "Save locally", variant="primary", scale=2, visible=False
733
+ )
734
+ csv_file = gr.File(label="CSV", elem_classes="datasets", visible=False)
735
+ json_file = gr.File(label="JSON", elem_classes="datasets", visible=False)
736
  with gr.Column(scale=3):
737
  success_message = gr.Markdown(
738
  visible=True,
 
880
  outputs=[pipeline_code_ui],
881
  )
882
 
883
+ btn_save_local.click(
884
+ save_local,
885
+ inputs=[
886
+ search_in,
887
+ file_in,
888
+ input_type,
889
+ system_prompt,
890
+ document_column,
891
+ retrieval_reranking,
892
+ num_rows,
893
+ temperature,
894
+ repo_name,
895
+ temperature_completion,
896
+ ],
897
+ outputs=[csv_file, json_file]
898
+ )
899
+
900
  clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
901
  clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
902
  clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
 
910
  app.load(fn=get_org_dropdown, outputs=[org_name])
911
  app.load(fn=get_random_repo_name, outputs=[repo_name])
912
  app.load(fn=show_temperature_completion, outputs=[temperature_completion])
913
+ if SAVE_LOCAL_DIR is not None:
914
+ app.load(fn=show_save_local, outputs=[btn_save_local, csv_file, json_file])
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import json
2
  import random
3
  import uuid
@@ -19,7 +20,7 @@ from synthetic_dataset_generator.apps.base import (
19
  validate_argilla_user_workspace_dataset,
20
  validate_push_to_hub,
21
  )
22
- from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
23
  from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
24
  from synthetic_dataset_generator.pipelines.embeddings import (
25
  get_embeddings,
@@ -406,6 +407,33 @@ def push_dataset(
406
  return ""
407
 
408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  def validate_input_labels(labels: List[str]) -> List[str]:
410
  if (
411
  not labels
@@ -425,6 +453,14 @@ def hide_pipeline_code_visibility():
425
  return {pipeline_code_ui: gr.Accordion(visible=False)}
426
 
427
 
 
 
 
 
 
 
 
 
428
  ######################
429
  # Gradio UI
430
  ######################
@@ -543,7 +579,14 @@ with gr.Blocks() as app:
543
  interactive=True,
544
  scale=1,
545
  )
546
- btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
 
 
 
 
 
 
 
547
  with gr.Column(scale=3):
548
  success_message = gr.Markdown(
549
  visible=True,
@@ -643,6 +686,21 @@ with gr.Blocks() as app:
643
  inputs=[],
644
  outputs=[pipeline_code_ui],
645
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
646
 
647
  gr.on(
648
  triggers=[clear_btn_part.click, clear_btn_full.click],
@@ -660,3 +718,5 @@ with gr.Blocks() as app:
660
  app.load(fn=swap_visibility, outputs=main_ui)
661
  app.load(fn=get_org_dropdown, outputs=[org_name])
662
  app.load(fn=get_random_repo_name, outputs=[repo_name])
 
 
 
1
+ import os
2
  import json
3
  import random
4
  import uuid
 
20
  validate_argilla_user_workspace_dataset,
21
  validate_push_to_hub,
22
  )
23
+ from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, SAVE_LOCAL_DIR
24
  from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
25
  from synthetic_dataset_generator.pipelines.embeddings import (
26
  get_embeddings,
 
407
  return ""
408
 
409
 
410
def save_local(
    system_prompt: str,
    difficulty: str,
    clarity: str,
    labels: List[str],
    multi_label: bool,
    num_rows: int,
    temperature: float,
    repo_name: str,
) -> tuple[str, str]:
    """Generate a text-classification dataset and write it to ``SAVE_LOCAL_DIR`` as CSV and JSON.

    Args:
        system_prompt: Forwarded to ``generate_dataset``.
        difficulty: Forwarded to ``generate_dataset``.
        clarity: Forwarded to ``generate_dataset``.
        labels: Forwarded to ``generate_dataset``.
        multi_label: Forwarded to ``generate_dataset``.
        num_rows: Forwarded to ``generate_dataset``.
        temperature: Forwarded to ``generate_dataset``.
        repo_name: Base name (without extension) of the two output files.

    Returns:
        ``(output_csv, output_json)``: the paths of the two files written
        inside ``SAVE_LOCAL_DIR``.
    """
    dataframe = generate_dataset(
        system_prompt=system_prompt,
        difficulty=difficulty,
        clarity=clarity,
        multi_label=multi_label,
        labels=labels,
        num_rows=num_rows,
        temperature=temperature,
    )
    local_dataset = Dataset.from_pandas(dataframe)
    # NOTE(review): assumes SAVE_LOCAL_DIR is set; os.path.join raises
    # TypeError when it is None (the UI hides the button in that case).
    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
    local_dataset.to_csv(output_csv, index=False)
    local_dataset.to_json(output_json, index=False)
    return output_csv, output_json
435
+
436
+
437
  def validate_input_labels(labels: List[str]) -> List[str]:
438
  if (
439
  not labels
 
453
  return {pipeline_code_ui: gr.Accordion(visible=False)}
454
 
455
 
456
def show_save_local():
    """Return component updates that make the local-save button and both file outputs visible."""
    revealed = {}
    revealed[btn_save_local] = gr.Button(visible=True)
    revealed[csv_file] = gr.File(visible=True)
    revealed[json_file] = gr.File(visible=True)
    return revealed
462
+
463
+
464
  ######################
465
  # Gradio UI
466
  ######################
 
579
  interactive=True,
580
  scale=1,
581
  )
582
+ btn_push_to_hub = gr.Button(
583
+ "Push to Hub", variant="primary", scale=2
584
+ )
585
+ btn_save_local = gr.Button(
586
+ "Save locally", variant="primary", scale=2, visible=False
587
+ )
588
+ csv_file = gr.File(label="CSV", elem_classes="datasets", visible=False)
589
+ json_file = gr.File(label="JSON", elem_classes="datasets", visible=False)
590
  with gr.Column(scale=3):
591
  success_message = gr.Markdown(
592
  visible=True,
 
686
  inputs=[],
687
  outputs=[pipeline_code_ui],
688
  )
689
+
690
+ btn_save_local.click(
691
+ save_local,
692
+ inputs=[
693
+ system_prompt,
694
+ difficulty,
695
+ clarity,
696
+ labels,
697
+ multi_label,
698
+ num_rows,
699
+ temperature,
700
+ repo_name,
701
+ ],
702
+ outputs=[csv_file, json_file]
703
+ )
704
 
705
  gr.on(
706
  triggers=[clear_btn_part.click, clear_btn_full.click],
 
718
  app.load(fn=swap_visibility, outputs=main_ui)
719
  app.load(fn=get_org_dropdown, outputs=[org_name])
720
  app.load(fn=get_random_repo_name, outputs=[repo_name])
721
+ if SAVE_LOCAL_DIR is not None:
722
+ app.load(fn=show_save_local, outputs=[btn_save_local, csv_file, json_file])
src/synthetic_dataset_generator/constants.py CHANGED
@@ -8,6 +8,9 @@ MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
8
  MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
9
  DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
10
 
 
 
 
11
  # Models
12
  MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
13
  TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)
 
8
  MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
9
  DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
10
 
11
+ # Directory for outputs
12
+ SAVE_LOCAL_DIR = os.getenv(key="SAVE_LOCAL_DIR", default=None)
13
+
14
  # Models
15
  MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
16
  TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)