import multiprocessing
import os
import time

import gradio as gr
import nltk
from unstructured.partition.pdf import partition_pdf

from distilabel.pipeline import Pipeline
from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps import LoadDataFromDicts, KeepColumns
from distilabel.steps.tasks import TextGeneration

from personas import TextToPersona  # local module expected to define the TextToPersona task

# unstructured relies on NLTK's "punkt" tokenizer when partitioning PDFs
nltk.download("punkt", quiet=True)
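
# Jinja template for the instruction-generation step; distilabel renders
# {{ persona }} from each row's "persona" column.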
PROMPT_TEMPLATE = """\ | |
Generate a single prompt the persona below might ask to an AI assistant: | |
{{ persona }} | |
""" | |

# Read the Hugging Face token from the environment
HF_TOKEN = os.environ.get("HF_TOKEN")
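

# Extract the raw text of each uploaded PDF into a {"text": ...} record,
# one per document, ready to be loaded into the pipeline.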
def process_pdfs(pdf_files):
    all_data = []
    for pdf_file in pdf_files:
        elements = partition_pdf(pdf_file.name)
        full_text = ""
        for element in elements:
            if element.text:  # skip page breaks and other empty elements
                full_text += element.text + "\n"
        all_data.append({"text": full_text.strip()})
    return all_data
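

# Build and run the distilabel pipeline:
#   document text -> persona (TextToPersona)
#                 -> instruction (TextGeneration with PROMPT_TEMPLATE)
#                 -> response (TextGeneration)
# Executed in a child process so the Gradio handler can monitor it.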
def _run_pipeline(result_queue, pdf_files):
    data = process_pdfs(pdf_files)
    with Pipeline(name="personahub-fineweb-edu-text-to-persona") as pipeline:
        input_batch_size = 10

        data_loader = LoadDataFromDicts(data=data)

        llm = InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
            api_key=HF_TOKEN,
        )

        # Infer a persona from the raw text of each document
        text_to_persona = TextToPersona(
            llm=llm,
            input_batch_size=input_batch_size,
        )

        # Turn each persona into a plausible user instruction
        text_gen = TextGeneration(
            llm=llm,
            system_prompt="You are an AI assistant expert at simulating user interactions.",
            template=PROMPT_TEMPLATE,
            columns="persona",
            output_mappings={"generation": "instruction"},
            num_generations=1,
        )

        # Answer each generated instruction to complete the SFT pair
        response_gen = TextGeneration(
            llm=llm,
            system_prompt="You are an AI assistant expert in responding to tasks.",
            output_mappings={"generation": "response"},
        )

        keep = KeepColumns(
            columns=["text", "persona", "model_name", "instruction", "response"],
            input_batch_size=input_batch_size,
        )

        data_loader >> text_to_persona >> text_gen >> response_gen >> keep

    distiset = pipeline.run(use_cache=False)
    result_queue.put(distiset)
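

# Run the pipeline in a subprocess, polling it to keep the Gradio
# progress bar moving while generation is in flight.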
def generate_dataset(pdf_files, progress=gr.Progress()):
    result_queue = multiprocessing.Queue()
    p = multiprocessing.Process(
        target=_run_pipeline,
        args=(result_queue, pdf_files),
    )
    try:
        p.start()
        total_steps = 100
        for step in range(total_steps):
            if not p.is_alive():
                break
            progress(
                (step + 1) / total_steps,
                desc="Generating dataset. Don't close this window.",
            )
            time.sleep(2)  # polling interval; adjust to the expected runtime
        p.join()
    except Exception as e:
        raise gr.Error(f"An error occurred during dataset generation: {str(e)}")

    if result_queue.empty():
        raise gr.Error("Dataset generation finished without producing a result.")
    distiset = result_queue.get()
    df = distiset["default"]["train"].to_pandas()
    progress(1.0, desc="Dataset generation completed")
    return df
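

# Handler wired to the "Generate Dataset" button below.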
def gradio_interface(pdf_files):
    if HF_TOKEN is None:
        raise gr.Error(
            "HF_TOKEN environment variable is not set. Please set it and restart the application."
        )
    if not pdf_files:
        raise gr.Error("Please upload at least one PDF.")
    return generate_dataset(pdf_files)
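

# Gradio UI: PDF upload, a trigger button, and a table for the generated rows.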
with gr.Blocks(title="MyPersonas Dataset Generator") as app:
    gr.Markdown("# MyPersonas Dataset Generator")
    gr.Markdown("Upload one or more PDFs to generate a persona-based SFT dataset.")

    with gr.Row():
        pdf_files = gr.File(label="Upload PDFs", file_count="multiple")
    with gr.Row():
        generate_button = gr.Button("Generate Dataset")

    output_dataframe = gr.DataFrame(
        label="Generated Dataset",
        interactive=False,
        wrap=True,
    )

    generate_button.click(
        fn=gradio_interface,
        inputs=[pdf_files],
        outputs=[output_dataframe],
    )

if __name__ == "__main__":
    app.launch()