import base64
import json
import os

import pandas as pd
import streamlit as st
from langchain.callbacks import get_openai_callback
from langchain_openai import ChatOpenAI

from autoqa_chain import auto_qa_chain
from chain_of_density import chain_of_density_chain
from chat_chains import (
    ai_response_format,
    format_docs,
    parse_context_and_question,
    parse_model_response,
    qa_chain,
)
from command_center import CommandCenter
from custom_exceptions import InvalidArgumentError, InvalidCommandError
from embed_documents import create_retriever
from insights_bullet_chain import insights_bullet_chain
from insights_mind_map_chain import insights_mind_map_chain
from openai_configuration import openai_parser
from process_documents import num_tokens, process_documents
from summary_chain import summary_chain
from synopsis_chain import synopsis_chain
from tldr_chain import tldr_chain

st.set_page_config(layout="wide")
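
# To launch the app (assuming this module is saved as app.py):
#     streamlit run app.py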

welcome_message = """
Hi, I'm Agent Zeta, your AI assistant, dedicated to making your journey through machine learning research papers as insightful and interactive as possible.
Whether you're diving into the latest studies or brushing up on foundational papers, I'm here to help you navigate, discuss, and analyze content.

Here's a quick guide to getting started with me:

| Command | Description |
|---------|-------------|
| `/configure --key <api key> --model <model>` | Configure the OpenAI API key and model for our conversation. |
| `/add-papers <list of urls>` | Upload and process documents for our conversation. |
| `/library` | View an index of processed documents to easily navigate your research. |
| `/view-snip <snippet id>` | View the content of a specific snippet. |
| `/session-expense` | Calculate the cost of our conversation, ensuring transparency in resource usage. |
| `/export` | Download conversation data for your records or further analysis. |
| `/auto-insight <list of snippet ids>` | Automatically generate questions and answers for the paper. |
| `/condense-summary <list of snippet ids>` | Generate increasingly concise, entity-dense summaries of the paper. |
| `/insight-bullets <list of snippet ids>` | Extract and summarize key insights, methods, results, and conclusions. |
| `/insight-mind-map <list of snippet ids>` | Create a structured outline of the key insights in Markdown format. |
| `/paper-synopsis <list of snippet ids>` | Generate a synopsis of the paper. |
| `/deep-dive [<list of snippet ids>] <query>` | Query me with a specific context. |
| `/summarise-section [<list of snippet ids>] <section name>` | Summarize a specific section of the paper. |
| `/tldr <list of snippet ids>` | Generate a TL;DR summary of the paper. |

<br>

Feel free to use these commands to enhance your research experience. Let's embark on this exciting journey of discovery together!
Use `/help-me` at any time to view this guide again.
"""
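
# Example commands (illustrative only; snippet ids depend on the papers you add):
#     /configure --key sk-... --model gpt-4
#     /add-papers https://arxiv.org/pdf/1706.03762
#     /deep-dive [0, 3] What attention mechanism does the paper propose?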


def process_documents_wrapper(inputs):
    """Handle /add-papers: download, chunk, and index the given documents."""
    if not inputs:
        raise InvalidArgumentError("Please provide document URLs")
    snippets, documents = process_documents(inputs)
    st.session_state.retriever = create_retriever(snippets)
    st.session_state.source_doc_urls = inputs
    st.session_state.index = [
        [
            snip.metadata["chunk_id"],
            snip.metadata["header"],
            num_tokens(snip.page_content),
        ]
        for snip in snippets
    ]
    response = f"Uploaded and processed documents {inputs}"
    st.session_state.messages.append((f"/add-papers {inputs}", response, "identity"))
    st.session_state.documents = documents
    return (response, "identity")


def index_documents_wrapper(inputs=None):
    """Handle /library: list processed snippets with their ids and token counts."""
    response = pd.DataFrame(
        st.session_state.index, columns=["id", "reference", "tokens"]
    )
    st.session_state.messages.append(("/library", response, "dataframe"))
    return (response, "dataframe")


def view_document_wrapper(inputs):
    """Handle /view-snip: show the content of a single snippet."""
    response = st.session_state.documents[inputs].page_content
    st.session_state.messages.append((f"/view-snip {inputs}", response, "identity"))
    return (response, "identity")


def calculate_cost_wrapper(inputs=None):
    """Handle /session-expense: tabulate token usage and cost, with a running total."""
    try:
        stats_df = pd.DataFrame(st.session_state.costing)
        stats_df.loc["total"] = stats_df.sum()
        response = stats_df
    except ValueError:
        response = "No cost incurred yet"
    st.session_state.messages.append(("/session-expense", response, "dataframe"))
    return (response, "dataframe")


def download_conversation_wrapper(inputs=None):
    """Handle /export: serialize the session to JSON and return a download link."""
    costing = st.session_state.get("costing", [])
    conversation_data = json.dumps(
        {
            "document_urls": st.session_state.get("source_doc_urls", []),
            "document_snippets": st.session_state.get("index", []),
            "conversation": [
                {"human": message[0], "ai": jsonify_functions[message[2]](message[1])}
                for message in st.session_state.messages
            ],
            "costing": costing,
            "total_cost": (
                {k: sum(d[k] for d in costing) for k in costing[0]}
                if costing
                else {}
            ),
        }
    )
    # Encode the JSON payload as base64 so it can be embedded in a data URI.
    conversation_data = base64.b64encode(conversation_data.encode()).decode()
    st.session_state.messages.append(
        ("/export", "Conversation data downloaded", "identity")
    )
    return (
        f'<a href="data:application/json;base64,{conversation_data}" download="conversation_data.json">Download Conversation</a>',
        "identity",
    )


def query_llm(inputs, relevant_docs):
    """Answer a question against the given snippets, recording citations and cost."""
    with get_openai_callback() as cb:
        response = (
            qa_chain(ChatOpenAI(model=st.session_state.model, temperature=0))
            .invoke({"context": format_docs(relevant_docs), "question": inputs})
            .content
        )
        stats = cb
    response = parse_model_response(response)
    answer = response["answer"]
    citations = response["citations"]
    # Append a catch-all citation listing every retrieved snippet id.
    citations.append(
        {
            "source_id": " ".join(
                f"[{chunk_id}]"
                for chunk_id in sorted(
                    str(doc.metadata["chunk_id"]) for doc in relevant_docs
                )
            ),
            "quote": "other sources",
        }
    )
    st.session_state.messages.append(
        (inputs, {"answer": answer, "citations": citations}, "response_with_citations")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return ({"answer": answer, "citations": citations}, "response_with_citations")


def rag_llm_wrapper(inputs):
    """Default handler: retrieve snippets relevant to the query and answer from them."""
    retriever = st.session_state.retriever
    relevant_docs = retriever.get_relevant_documents(inputs)
    return query_llm(inputs, relevant_docs)


def query_llm_wrapper(inputs):
    """Handle /deep-dive: answer a query against an explicit set of snippet ids."""
    context, question = parse_context_and_question(inputs)
    relevant_docs = [st.session_state.documents[c] for c in context]
    return query_llm(question, relevant_docs)


def summarise_wrapper(inputs):
    """Handle /summarise-section: summarize a named section of the given snippets."""
    context, query = parse_context_and_question(inputs)
    document = [st.session_state.documents[c] for c in context]
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = summary_chain(llm).invoke({"section_name": query, "paper": document})
        stats = cb
    st.session_state.messages.append(
        (f"/summarise-section {query}", summary, "identity")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (summary, "identity")


def chain_of_density_wrapper(inputs):
    """Handle /condense-summary: generate increasingly entity-dense summaries."""
    if not inputs:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join(st.session_state.documents[c].page_content for c in inputs)
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = chain_of_density_chain(llm).invoke({"paper": document})
        stats = cb
    st.session_state.messages.append(("/condense-summary", summary, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (summary, "identity")


def synopsis_wrapper(inputs):
    """Handle /paper-synopsis: generate a synopsis of the selected snippets."""
    if not inputs:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join(st.session_state.documents[c].page_content for c in inputs)
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = synopsis_chain(llm).invoke({"paper": document})
        stats = cb
    st.session_state.messages.append(("/paper-synopsis", summary, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (summary, "identity")


def tldr_wrapper(inputs):
    """Handle /tldr: generate a TL;DR summary of the selected snippets."""
    if not inputs:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join(st.session_state.documents[c].page_content for c in inputs)
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = tldr_chain(llm).invoke({"paper": document})
        stats = cb
    st.session_state.messages.append(("/tldr", summary, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (summary, "identity")


def insights_bullet_wrapper(inputs):
    """Handle /insight-bullets: extract key insights, methods, results, and conclusions."""
    if not inputs:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join(st.session_state.documents[c].page_content for c in inputs)
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        insights = insights_bullet_chain(llm).invoke({"paper": document})
        stats = cb
    st.session_state.messages.append(("/insight-bullets", insights, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (insights, "identity")


def insights_mind_map_wrapper(inputs):
    """Handle /insight-mind-map: outline the key insights in Markdown."""
    if not inputs:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join(st.session_state.documents[c].page_content for c in inputs)
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        insights = insights_mind_map_chain(llm).invoke({"paper": document})
        stats = cb
    st.session_state.messages.append(("/insight-mind-map", insights, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (insights, "identity")


def auto_qa_chain_wrapper(inputs):
    """Handle /auto-insight: generate questions per section and answer them via retrieval."""
    if not inputs:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join(st.session_state.documents[c].page_content for c in inputs)
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    retriever = st.session_state.retriever
    formatted_response = ""
    # Keep the answering loop inside the callback so its token usage is costed too.
    with get_openai_callback() as cb:
        auto_qa_response = auto_qa_chain(llm).invoke({"paper": document})
        for section in auto_qa_response:
            section_name = section["section_name"]
            formatted_response += f"# {section_name}\n"
            for question in section["questions"]:
                response = (
                    qa_chain(llm)
                    .invoke(
                        {
                            "context": format_docs(
                                retriever.get_relevant_documents(question)
                            ),
                            "question": question,
                        }
                    )
                    .content
                )
                answer = parse_model_response(response)["answer"]
                formatted_response += f"## {question}\n"
                formatted_response += f"* {answer}\n"
        stats = cb
    formatted_response = "```\n" + formatted_response + "\n```"
    st.session_state.messages.append(
        (f"/auto-insight {inputs}", formatted_response, "identity")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (formatted_response, "identity")


def boot(command_center, formatting_functions):
    """Render the chat UI, replay message history, and dispatch user commands."""
    st.write("# Agent Zeta")
    if "costing" not in st.session_state:
        st.session_state.costing = []
    if "messages" not in st.session_state:
        st.session_state.messages = []
    st.chat_message("ai").write(welcome_message, unsafe_allow_html=True)
    for message in st.session_state.messages:
        st.chat_message("human").write(message[0])
        st.chat_message("ai").write(
            formatting_functions[message[2]](message[1]), unsafe_allow_html=True
        )
    if query := st.chat_input():
        try:
            st.chat_message("human").write(query)
            response, format_fn_name = command_center.execute_command(query)
            st.chat_message("ai").write(
                formatting_functions[format_fn_name](response), unsafe_allow_html=True
            )
        except (InvalidArgumentError, InvalidCommandError) as e:
            st.error(e)


def configure_openai_wrapper(inputs):
    """Handle /configure: set the OpenAI API key and model for the session."""
    args = openai_parser.parse_args(inputs.split())
    os.environ["OPENAI_API_KEY"] = args.key
    st.session_state.model = args.model
    st.session_state.messages.append(("/configure", "Configuration saved", "identity"))
    return (str(args), "identity")


if __name__ == "__main__":
    all_commands = [
        ("/configure", str, configure_openai_wrapper),
        ("/add-papers", list, process_documents_wrapper),
        ("/library", None, index_documents_wrapper),
        ("/view-snip", str, view_document_wrapper),
        ("/session-expense", None, calculate_cost_wrapper),
        ("/export", None, download_conversation_wrapper),
        ("/help-me", None, lambda x: (welcome_message, "identity")),
        ("/auto-insight", list, auto_qa_chain_wrapper),
        ("/deep-dive", str, query_llm_wrapper),
        ("/condense-summary", list, chain_of_density_wrapper),
        ("/insight-bullets", list, insights_bullet_wrapper),
        ("/insight-mind-map", list, insights_mind_map_wrapper),
        ("/paper-synopsis", list, synopsis_wrapper),
        ("/summarise-section", str, summarise_wrapper),
        ("/tldr", list, tldr_wrapper),
    ]
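    # Registering a new command is a matter of appending a tuple above (sketch,
    # assuming CommandCenter dispatches on the leading token), e.g.
    #     ("/word-count", list, word_count_wrapper)
    # where word_count_wrapper is a hypothetical handler that accepts the parsed
    # arguments and returns a (response, format_fn_name) pair whose name is a key
    # in formatting_functions.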
    command_center = CommandCenter(
        default_input_type=str,
        default_function=rag_llm_wrapper,
        all_commands=all_commands,
    )
    formatting_functions = {
        "identity": lambda x: x,
        "dataframe": lambda x: x,
        "response_with_citations": lambda x: ai_response_format(
            x["answer"], x["citations"]
        ),
    }
    jsonify_functions = {
        "identity": lambda x: x,
        "dataframe": lambda x: (
            x.to_dict(orient="records")
            if isinstance(x, (pd.DataFrame, pd.Series))
            else x
        ),
        "response_with_citations": lambda x: x,
    }
    boot(command_center, formatting_functions)