| 
							 | 
						import gradio as gr | 
					
					
						
						| 
							 | 
						import pandas as pd | 
					
					
						
						| 
							 | 
						import pixeltable as pxt | 
					
					
						
						| 
							 | 
						from pixeltable.iterators import DocumentSplitter | 
					
					
						
						| 
							 | 
						import numpy as np | 
					
					
						
						| 
							 | 
						from pixeltable.functions.huggingface import sentence_transformer | 
					
					
						
						| 
							 | 
						from pixeltable.functions import openai | 
					
					
						
						| 
							 | 
						import os | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						"""## Store OpenAI API Key""" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						if 'OPENAI_API_KEY' not in os.environ: | 
					
					
						
						| 
							 | 
						    os.environ['OPENAI_API_KEY'] = getpass.getpass('Enter your OpenAI API key:') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						"""Pixeltable Set up""" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						pxt.drop_dir('rag_demo', force=True) | 
					
					
						
						| 
							 | 
						pxt.create_dir('rag_demo') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						@pxt.expr_udf | 
					
					
						
						| 
							 | 
						def e5_embed(text: str) -> np.ndarray: | 
					
					
						
						| 
							 | 
						    return sentence_transformer(text, model_id='intfloat/e5-large-v2') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						@pxt.udf | 
					
					
						
						| 
							 | 
						def create_prompt(top_k_list: list[dict], question: str) -> str: | 
					
					
						
						| 
							 | 
						    concat_top_k = '\n\n'.join( | 
					
					
						
						| 
							 | 
						        elt['text'] for elt in reversed(top_k_list) | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						    return f''' | 
					
					
						
						| 
							 | 
						    PASSAGES: | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    {concat_top_k} | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    QUESTION: | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    {question}''' | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						def process_files(ground_truth_file, pdf_files): | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    pxt.drop_dir('rag_demo', force=True) | 
					
					
						
						| 
							 | 
						    pxt.create_dir('rag_demo') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    if ground_truth_file.name.endswith('.csv'): | 
					
					
						
						| 
							 | 
						        queries_t = pxt.io.import_csv('rag_demo.queries', ground_truth_file.name) | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        queries_t = pxt.io.import_excel('rag_demo.queries', ground_truth_file.name) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    documents_t = pxt.create_table( | 
					
					
						
						| 
							 | 
						        'rag_demo.documents', | 
					
					
						
						| 
							 | 
						        {'document': pxt.DocumentType()} | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    documents_t.insert({'document': file.name} for file in pdf_files if file.name.endswith('.pdf')) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    chunks_t = pxt.create_view( | 
					
					
						
						| 
							 | 
						        'rag_demo.chunks', | 
					
					
						
						| 
							 | 
						        documents_t, | 
					
					
						
						| 
							 | 
						        iterator=DocumentSplitter.create( | 
					
					
						
						| 
							 | 
						            document=documents_t.document, | 
					
					
						
						| 
							 | 
						            separators='token_limit', | 
					
					
						
						| 
							 | 
						            limit=300 | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    chunks_t.add_embedding_index('text', string_embed=e5_embed) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    @chunks_t.query | 
					
					
						
						| 
							 | 
						    def top_k(query_text: str): | 
					
					
						
						| 
							 | 
						      sim = chunks_t.text.similarity(query_text) | 
					
					
						
						| 
							 | 
						      return ( | 
					
					
						
						| 
							 | 
						          chunks_t.order_by(sim, asc=False) | 
					
					
						
						| 
							 | 
						              .select(chunks_t.text, sim=sim) | 
					
					
						
						| 
							 | 
						              .limit(5) | 
					
					
						
						| 
							 | 
						      ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    queries_t['question_context'] = chunks_t.top_k(queries_t.Question) | 
					
					
						
						| 
							 | 
						    queries_t['prompt'] = create_prompt( | 
					
					
						
						| 
							 | 
						        queries_t.question_context, queries_t.Question | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    messages = [ | 
					
					
						
						| 
							 | 
						        { | 
					
					
						
						| 
							 | 
						            'role': 'system', | 
					
					
						
						| 
							 | 
						            'content': 'Please read the following passages and answer the question based on their contents.' | 
					
					
						
						| 
							 | 
						        }, | 
					
					
						
						| 
							 | 
						        { | 
					
					
						
						| 
							 | 
						            'role': 'user', | 
					
					
						
						| 
							 | 
						            'content': queries_t.prompt | 
					
					
						
						| 
							 | 
						        } | 
					
					
						
						| 
							 | 
						    ] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						      | 
					
					
						
						| 
							 | 
						    queries_t['response'] = openai.chat_completions( | 
					
					
						
						| 
							 | 
						        model='gpt-4o-mini-2024-07-18', messages=messages | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    queries_t['answer'] = queries_t.response.choices[0].message.content.astype(pxt.StringType()) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    df_output = queries_t.select(queries_t.Question, queries_t.correct_answer, queries_t.answer).collect().to_pandas() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    try: | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        return df_output | 
					
					
						
						| 
							 | 
						    except Exception as e: | 
					
					
						
						| 
							 | 
						        return f"An error occurred: {str(e)}", None | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						with gr.Blocks() as demo: | 
					
					
						
						| 
							 | 
						    gr.Markdown("# RAG Demo App") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    with gr.Row(): | 
					
					
						
						| 
							 | 
						        ground_truth_file = gr.File(label="Upload Ground Truth (CSV or XLSX)", file_count="single") | 
					
					
						
						| 
							 | 
						        pdf_files = gr.File(label="Upload PDF Documents", file_count="multiple") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    process_button = gr.Button("Process Files and Generate Outputs") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    df_output = gr.DataFrame(label="Pixeltable Table") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    process_button.click(process_files, inputs=[ground_truth_file, pdf_files], outputs=df_output) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						if __name__ == "__main__": | 
					
					
						
						| 
							 | 
						    demo.launch() |