umer70112254 commited on
Commit
03fcf77
·
1 Parent(s): 365661e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -132
app.py CHANGED
@@ -1,159 +1,144 @@
1
- import os
2
- from typing import Dict, Any
3
- import asyncio
4
 
5
- # Check if the OPENAI_API_KEY is set
6
- if "OPENAI_API_KEY" not in os.environ:
7
- raise ValueError("No OPENAI_API_KEY found. Please set the environment variable.")
8
 
9
- # Create a new event loop
10
- loop = asyncio.new_event_loop()
 
11
 
12
- # Set the event loop as the current event loop
13
- asyncio.set_event_loop(loop)
14
 
15
- from llama_index import (
16
- VectorStoreIndex,
17
- ServiceContext,
18
- download_loader,
19
- )
20
- from llama_index.llama_pack.base import BaseLlamaPack
21
- from llama_index.llms import OpenAI
22
 
23
 
24
- class StreamlitChatPack(BaseLlamaPack):
25
- """Streamlit chatbot pack."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def __init__(
28
  self,
29
- wikipedia_page: str = "Snowflake Inc.",
30
- run_from_main: bool = False,
31
  **kwargs: Any,
32
  ) -> None:
33
  """Init params."""
34
- if not run_from_main:
35
- raise ValueError(
36
- "Please run this llama-pack directly with "
37
- "`streamlit run [download_dir]/streamlit_chatbot/base.py`"
38
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- self.wikipedia_page = wikipedia_page
 
41
 
42
  def get_modules(self) -> Dict[str, Any]:
43
  """Get modules."""
44
- return {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  def run(self, *args: Any, **kwargs: Any) -> Any:
47
  """Run the pipeline."""
48
- import streamlit as st
49
- from streamlit_pills import pills
50
-
51
- st.set_page_config(
52
- page_title=f"Chat with {self.wikipedia_page}'s Wikipedia page, powered by LlamaIndex",
53
- page_icon="🦙",
54
- layout="centered",
55
- initial_sidebar_state="auto",
56
- menu_items=None,
57
- )
58
-
59
- if "messages" not in st.session_state: # Initialize the chat messages history
60
- st.session_state["messages"] = [
61
- {"role": "assistant", "content": "Ask me a question about Snowflake!"}
62
- ]
63
 
64
- st.title(
65
- f"Chat with {self.wikipedia_page}'s Wikipedia page, powered by LlamaIndex 💬🦙"
 
66
  )
67
- st.info(
68
- "This example is powered by the **[Llama Hub Wikipedia Loader](https://llamahub.ai/l/wikipedia)**. Use any of [Llama Hub's many loaders](https://llamahub.ai/) to retrieve and chat with your data via a Streamlit app.",
69
- icon="ℹ️",
70
- )
71
-
72
- def add_to_message_history(role, content):
73
- message = {"role": role, "content": str(content)}
74
- st.session_state["messages"].append(
75
- message
76
- ) # Add response to message history
77
-
78
- @st.cache_resource
79
- def load_index_data():
80
- WikipediaReader = download_loader(
81
- "WikipediaReader", custom_path="local_dir"
82
- )
83
- loader = WikipediaReader()
84
- docs = loader.load_data(pages=[self.wikipedia_page])
85
- service_context = ServiceContext.from_defaults(
86
- llm=OpenAI(model="gpt-3.5-turbo", temperature=0.5)
87
- )
88
- index = VectorStoreIndex.from_documents(
89
- docs, service_context=service_context
90
  )
91
- return index
92
-
93
- index = load_index_data()
94
-
95
- selected = pills(
96
- "Choose a question to get started or write your own below.",
97
- [
98
- "What is Snowflake?",
99
- "What company did Snowflake announce they would acquire in October 2023?",
100
- "What company did Snowflake acquire in March 2022?",
101
- "When did Snowflake IPO?",
102
- ],
103
- clearable=True,
104
- index=None,
105
- )
106
-
107
- if "chat_engine" not in st.session_state: # Initialize the query engine
108
- st.session_state["chat_engine"] = index.as_chat_engine(
109
- chat_mode="context", verbose=True
110
  )
 
111
 
112
- for message in st.session_state["messages"]: # Display the prior chat messages
113
- with st.chat_message(message["role"]):
114
- st.write(message["content"])
115
-
116
- # To avoid duplicated display of answered pill questions each rerun
117
- if selected and selected not in st.session_state.get(
118
- "displayed_pill_questions", set()
119
- ):
120
- st.session_state.setdefault("displayed_pill_questions", set()).add(selected)
121
- with st.chat_message("user"):
122
- st.write(selected)
123
- with st.chat_message("assistant"):
124
- response = st.session_state["chat_engine"].stream_chat(selected)
125
- response_str = ""
126
- response_container = st.empty()
127
- for token in response.response_gen:
128
- response_str += token
129
- response_container.write(response_str)
130
- add_to_message_history("user", selected)
131
- add_to_message_history("assistant", response)
132
-
133
- if prompt := st.chat_input(
134
- "Your question"
135
- ): # Prompt for user input and save to chat history
136
- add_to_message_history("user", prompt)
137
-
138
- # Display the new question immediately after it is entered
139
- with st.chat_message("user"):
140
- st.write(prompt)
141
-
142
- # If last message is not from assistant, generate a new response
143
- # if st.session_state["messages"][-1]["role"] != "assistant":
144
- with st.chat_message("assistant"):
145
- response = st.session_state["chat_engine"].stream_chat(prompt)
146
- response_str = ""
147
- response_container = st.empty()
148
- for token in response.response_gen:
149
- response_str += token
150
- response_container.write(response_str)
151
- # st.write(response.response)
152
- add_to_message_history("assistant", response.response)
153
-
154
- # Save the state of the generator
155
- st.session_state["response_gen"] = response.response_gen
156
 
157
 
158
  if __name__ == "__main__":
159
- StreamlitChatPack(run_from_main=True).run()
 
1
+ from typing import Dict, Any, List, Tuple, Optional
 
 
2
 
3
+ from llama_index.llama_pack.base import BaseLlamaPack
4
+ from llama_index.llms import OpenAI
5
+ from llama_index.agent import ReActAgent
6
 
7
+ from llama_hub.tools.arxiv import ArxivToolSpec
8
+ from llama_hub.tools.wikipedia import WikipediaToolSpec
9
+ import functools
10
 
11
+ from io import StringIO
12
+ import sys
13
 
14
+
15
+ SUPPORTED_TOOLS = {
16
+ "arxiv_search_tool": ArxivToolSpec,
17
+ "wikipedia": WikipediaToolSpec,
18
+ }
 
 
19
 
20
 
21
+ class Capturing(list):
22
+ """To capture the stdout from ReActAgent.chat with verbose=True. Taken from
23
+ https://stackoverflow.com/questions/16571150/\
24
+ how-to-capture-stdout-output-from-a-python-function-call
25
+ """
26
+
27
+ def __enter__(self):
28
+ self._stdout = sys.stdout
29
+ sys.stdout = self._stringio = StringIO()
30
+ return self
31
+
32
+ def __exit__(self, *args):
33
+ self.extend(self._stringio.getvalue().splitlines())
34
+ del self._stringio # free up some memory
35
+ sys.stdout = self._stdout
36
+
37
+
38
+ class GradioReActAgentPack(BaseLlamaPack):
39
+ """Gradio chatbot to chat with a ReActAgent pack."""
40
 
41
  def __init__(
42
  self,
43
+ tools_list: Optional[List[str]] = [k for k in SUPPORTED_TOOLS.keys()],
 
44
  **kwargs: Any,
45
  ) -> None:
46
  """Init params."""
47
+ try:
48
+ from ansi2html import Ansi2HTMLConverter
49
+ except ImportError:
50
+ raise ImportError("Please install ansi2html via `pip install ansi2html`")
51
+
52
+ tools = []
53
+ for t in tools_list:
54
+ try:
55
+ tools.append(SUPPORTED_TOOLS[t]())
56
+ except KeyError:
57
+ raise KeyError(f"Tool {t} is not supported.")
58
+ self.tools = tools
59
+
60
+ self.llm = OpenAI(model="gpt-4-1106-preview", max_tokens=2000)
61
+ self.agent = ReActAgent.from_tools(
62
+ tools=functools.reduce(
63
+ lambda x, y: x.to_tool_list() + y.to_tool_list(), self.tools
64
+ ),
65
+ llm=self.llm,
66
+ verbose=True,
67
+ )
68
 
69
+ self.thoughts = ""
70
+ self.conv = Ansi2HTMLConverter()
71
 
72
  def get_modules(self) -> Dict[str, Any]:
73
  """Get modules."""
74
+ return {"agent": self.agent, "llm": self.llm, "tools": self.tools}
75
+
76
+ def _handle_user_message(self, user_message, history):
77
+ """Handle the user submitted message. Clear message box, and append
78
+ to the history."""
79
+ return "", history + [(user_message, "")]
80
+
81
+ def _generate_response(
82
+ self, chat_history: List[Tuple[str, str]]
83
+ ) -> Tuple[str, List[Tuple[str, str]]]:
84
+ """Generate the response from agent, and capture the stdout of the
85
+ ReActAgent's thoughts.
86
+ """
87
+ with Capturing() as output:
88
+ response = self.agent.stream_chat(chat_history[-1][0])
89
+ ansi = "\n========\n".join(output)
90
+ html_output = self.conv.convert(ansi)
91
+ for token in response.response_gen:
92
+ chat_history[-1][1] += token
93
+ yield chat_history, str(html_output)
94
+
95
+ def _reset_chat(self) -> Tuple[str, str]:
96
+ """Reset the agent's chat history. And clear all dialogue boxes."""
97
+ # clear agent history
98
+ self.agent.reset()
99
+ return "", "", "" # clear textboxes
100
 
101
  def run(self, *args: Any, **kwargs: Any) -> Any:
102
  """Run the pipeline."""
103
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+ demo = gr.Blocks(
106
+ theme="gstaff/xkcd",
107
+ css="#box { height: 420px; overflow-y: scroll !important}",
108
  )
109
+ with demo:
110
+ gr.Markdown(
111
+ "# Gradio ReActAgent Powered by LlamaIndex and LlamaHub 🦙\n"
112
+ "This Gradio app is powered by LlamaIndex's `ReActAgent` with\n"
113
+ "OpenAI's GPT-4-Turbo as the LLM. The tools are listed below.\n"
114
+ "## Tools\n"
115
+ "- [ArxivToolSpec](https://llamahub.ai/l/tools-arxiv)\n"
116
+ "- [WikipediaToolSpec](https://llamahub.ai/l/tools-wikipedia)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  )
118
+ with gr.Row():
119
+ chat_window = gr.Chatbot(
120
+ label="Message History",
121
+ scale=3,
122
+ )
123
+ console = gr.HTML(elem_id="box")
124
+ with gr.Row():
125
+ message = gr.Textbox(label="Write A Message", scale=4)
126
+ clear = gr.ClearButton()
127
+
128
+ message.submit(
129
+ self._handle_user_message,
130
+ [message, chat_window],
131
+ [message, chat_window],
132
+ queue=False,
133
+ ).then(
134
+ self._generate_response,
135
+ chat_window,
136
+ [chat_window, console],
137
  )
138
+ clear.click(self._reset_chat, None, [message, chat_window, console])
139
 
140
+ demo.launch(server_name="0.0.0.0", server_port=8080)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
 
143
  if __name__ == "__main__":
144
+ GradioReActAgentPack(run_from_main=True).run()