# Hugging Face Space by burtenshaw — first commit (7b3a105).
import io
import multiprocessing
import os
import queue
import time

import gradio as gr
import nltk
import pandas as pd
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts, KeepColumns
from distilabel.steps.tasks import TextGeneration
from unstructured.partition.pdf import partition_pdf

from personas import *  # Assuming this contains TextToPersona and other necessary definitions
nltk.download("punkt", quiet=True)
PROMPT_TEMPLATE = """\
Generate a single prompt the persona below might ask to an AI assistant:
{{ persona }}
"""
# Get HF_TOKEN from environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
def process_pdfs(pdf_files):
    """Extract plain text from each uploaded PDF.

    Args:
        pdf_files: Iterable of uploaded file objects, each exposing a
            ``.name`` filesystem path (as Gradio's File component provides).
            May be empty or None when nothing was uploaded.

    Returns:
        list[dict]: One ``{"text": ...}`` record per PDF, holding the
        newline-joined text of all parsed elements.
    """
    # Gradio passes None when the user submits without uploading anything.
    if not pdf_files:
        return []
    all_data = []
    for pdf_file in pdf_files:
        elements = partition_pdf(pdf_file.name)
        # Join in one pass instead of quadratic `+=`; skip elements whose
        # .text is None/empty (e.g. non-text elements), which previously
        # raised TypeError on `None + "\n"`.
        full_text = "\n".join(
            element.text for element in elements if element.text
        )
        all_data.append({"text": full_text.strip()})
    return all_data
def _run_pipeline(result_queue, pdf_files):
    """Build and execute the distilabel pipeline, posting its result to a queue.

    Runs inside a child process (spawned by generate_dataset); the finished
    distiset is sent back through ``result_queue`` rather than returned.

    Args:
        result_queue: multiprocessing.Queue the resulting distiset is put on.
        pdf_files: Uploaded PDF file objects whose text seeds the dataset.
    """
    data = process_pdfs(pdf_files)
    with Pipeline(name="personahub-fineweb-edu-text-to-persona") as pipeline:
        input_batch_size = 10
        # Feed the extracted PDF texts ({"text": ...} dicts) into the pipeline.
        data_loader = LoadDataFromDicts(data=data)
        # One shared inference-endpoint client reused by all generation steps.
        llm = InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
            api_key=HF_TOKEN,
        )
        # TextToPersona comes from the `personas` star import; presumably it
        # derives a persona description from each text — verify in personas.py.
        text_to_persona = TextToPersona(
            llm=llm,
            input_batch_size=input_batch_size,
        )
        # Render PROMPT_TEMPLATE with the "persona" column to produce a
        # simulated user prompt, stored as "instruction".
        text_gen = TextGeneration(
            llm=llm,
            system_prompt="You are an AI assistant expert at simulating user interactions.",
            template=PROMPT_TEMPLATE,
            columns="persona",
            output_mappings={"generation": "instruction"},
            num_generations=1,
        )
        # Answer the generated prompt; output stored as "response".
        # NOTE(review): no template/columns given — assumes TextGeneration's
        # default template consumes the "instruction" column; confirm against
        # the distilabel version in use.
        response_gen = TextGeneration(
            llm=llm,
            system_prompt="You are an AI assistant expert in responding to tasks",
            output_mappings={"generation": "response"},
        )
        # Prune intermediate columns so only the SFT-relevant ones remain.
        keep = KeepColumns(
            columns=["text", "persona", "model_name", "instruction", "response"],
            input_batch_size=input_batch_size,
        )
        # Linear DAG: load -> persona -> instruction -> response -> prune.
        (data_loader >> text_to_persona >> text_gen >> response_gen >> keep)
        distiset = pipeline.run(use_cache=False)
    result_queue.put(distiset)
def generate_dataset(pdf_files, progress=gr.Progress()):
    """Run the generation pipeline in a child process and return its output.

    Args:
        pdf_files: Uploaded PDF file objects, forwarded to ``_run_pipeline``.
        progress: Gradio progress tracker (injected by Gradio at call time;
            the call-time default is Gradio's documented idiom).

    Returns:
        pandas.DataFrame: The ``default``/``train`` split of the generated
        distiset.

    Raises:
        gr.Error: If the pipeline process fails or produces no result.
    """
    result_queue = multiprocessing.Queue()
    p = multiprocessing.Process(
        target=_run_pipeline,
        args=(result_queue, pdf_files),
    )
    try:
        p.start()
        total_steps = 100
        for step in range(total_steps):
            # is_alive() is the supported liveness check; the previous
            # p._popen.poll() poked a private attribute.
            if not p.is_alive():
                break
            progress(
                (step + 1) / total_steps,
                desc="Generating dataset. Don't close this window.",
            )
            time.sleep(2)  # Polling interval; adjust based on your needs.
        # Drain the queue BEFORE join(): joining first can deadlock when the
        # child is still blocked flushing a large distiset into the queue
        # (see "Joining processes that use queues" in the multiprocessing docs).
        try:
            distiset = result_queue.get(timeout=60)
        except queue.Empty:
            # No result: the child crashed or hung — don't block forever on
            # the blocking get() the original code used.
            if p.is_alive():
                p.terminate()
            p.join()
            raise gr.Error(
                f"Dataset generation produced no result (exit code {p.exitcode})."
            )
        p.join()
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"An error occurred during dataset generation: {str(e)}")
    df = distiset["default"]["train"].to_pandas()
    progress(1.0, desc="Dataset generation completed")
    return df
def gradio_interface(pdf_files):
    """Entry point wired to the Generate button.

    Validates that the HF token is configured, then builds and returns the
    dataset DataFrame for the uploaded PDFs.
    """
    # Fail fast with a user-visible error when the token is missing.
    if HF_TOKEN is None:
        raise gr.Error(
            "HF_TOKEN environment variable is not set. Please set it and restart the application."
        )
    return generate_dataset(pdf_files)
# --- Gradio UI layout ------------------------------------------------------
with gr.Blocks(title="MyPersonas Dataset Generator") as app:
    gr.Markdown("# MyPersonas Dataset Generator")
    gr.Markdown("Upload one or more PDFs to generate a persona based SFT dataset.")
    with gr.Row():
        # Accepts several PDFs in a single upload.
        pdf_files = gr.File(label="Upload PDFs", file_count="multiple")
    with gr.Row():
        generate_button = gr.Button("Generate Dataset")
    # Read-only results table; wrap long generated text for readability.
    output_dataframe = gr.DataFrame(
        label="Generated Dataset",
        interactive=False,
        wrap=True,
    )
    # Button click runs the full pipeline and fills the table.
    generate_button.click(
        fn=gradio_interface,
        inputs=[pdf_files],
        outputs=[output_dataframe],
    )
if __name__ == "__main__":
    app.launch()