long1104 committed on
Commit
6e58dd2
·
verified ·
1 Parent(s): 72e913d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +294 -0
app.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%writefile app.py
2
+ from setup_code import * # This imports everything from setup_code.py
3
+
4
+ general_greeting_num = 0
5
+ general_question_num = 1
6
+ machine_learning_num = 2
7
+ python_code_num = 3
8
+ obnoxious_num = 4
9
+ default_num = 5
10
+
11
+ query_classes = {'[General greeting]': general_greeting_num,
12
+ '[General question]': general_question_num,
13
+ '[Question about Machine Learning]': machine_learning_num,
14
+ '[Question about Python code]' : python_code_num,
15
+ '[Obnoxious statement]': obnoxious_num
16
+ }
17
+
18
+ query_classes_text = ", ".join(query_classes.keys())
19
+
20
+ class Classify_Agent:
21
+ def __init__(self, openai_client) -> None:
22
+ # TODO: Initialize the client and prompt for the Obnoxious_Agent
23
+ self.openai_client = openai_client
24
+
25
+ def classify_query(self, query):
26
+ prompt = f"Please classify this query in angle brackets <{query}> as one of the following in square brackets only: {query_classes_text}."
27
+ classification_response = get_completion(self.openai_client, prompt)
28
+
29
+ if classification_response != None and classification_response in query_classes.keys():
30
+ query_class = query_classes.get(classification_response, default_num)
31
+ # st.write(f"query <{query}>: {classification_response}")
32
+
33
+ return query_classes.get(classification_response, default_num)
34
+ else:
35
+ # st.write(f"query <{query}>: {classification_response}")
36
+ return default_num
37
+
38
+ class Relevant_Documents_Agent:
39
+ def __init__(self, openai_client) -> None:
40
+ # TODO: Initialize the Relevant_Documents_Agent
41
+ self.client = openai_client
42
+
43
+ def get_relevance(self, conversation) -> str:
44
+ pass
45
+
46
+ def get_relevant_docs(self, conversation, docs) -> str: # uses Query Agent to get relevant docs
47
+ pass
48
+
49
+ def is_relevant(self, matches_text, user_query_plus_conversation) -> bool:
50
+ prompt = f"Please confirm that the text in angle brackets: <{matches_text}>, is relevant to the text in double square brackets: [[{user_query_plus_conversation}]]. Return Yes or No"
51
+ response = get_completion(self.client, prompt)
52
+
53
+ return is_Yes(response)
54
+
55
+ class Query_Agent:
56
+ def __init__(self, pinecone_index, pinecone_index_python, openai_client, embeddings) -> None:
57
+ # TODO: Initialize the Query_Agent agent
58
+ self.pinecone_index = pinecone_index
59
+ self.pinecone_index_python = pinecone_index_python
60
+ self.openai_client = openai_client
61
+ self.embeddings = embeddings
62
+
63
+ def get_openai_embedding(self, text, model="text-embedding-ada-002"):
64
+ text = text.replace("\n", " ")
65
+ return self.openai_client.embeddings.create(input=[text], model=model).data[0].embedding
66
+
67
+ def query_vector_store(self, query, index=None, k=5) -> str:
68
+ if index == None:
69
+ index = self.pinecone_index
70
+
71
+ query_embedding = self.get_openai_embedding(query)
72
+
73
+ def get_namespace(index):
74
+ stat = index.describe_index_stats()
75
+ stat_dict_key = stat['namespaces'].keys()
76
+
77
+ stat_dict_key_list = list(stat_dict_key)
78
+ first_key = stat_dict_key_list[0]
79
+
80
+ return first_key
81
+
82
+ ns = get_namespace(index)
83
+
84
+ matches_text = get_top_k_text(index.query(
85
+ namespace=ns,
86
+ top_k=k,
87
+ vector=query_embedding,
88
+ include_values=True,
89
+ include_metadata=True
90
+ )
91
+ )
92
+ return matches_text
93
+
94
+ class Answering_Agent:
95
+ def __init__(self, openai_client) -> None:
96
+ # TODO: Initialize the Answering_Agent
97
+ self.client = openai_client
98
+
99
+ def generate_response(self, query, docs, conv_history, selected_mode):
100
+ # TODO: Generate a response to the user's query
101
+ prompt_for_gpt = f"Based on this text in angle brackets: <{docs}>, please summarize a response to this query: {query} in the context of this conversation: {conv_history}. Please use language appropriate for a {selected_mode}."
102
+ return get_completion(self.client, prompt_for_gpt)
103
+
104
+ def generate_image(self, text):
105
+ caption_prompt = f"Based on this text, repeated here in double square brackets for your reference: [[{text}]], please generate a simple caption that I can use with dall-e to generate an instructional image."
106
+ caption_text = get_completion(self.client, caption_prompt)
107
+ #st.write(caption_text)
108
+ image = Head_Agent.text_to_image(self.client, caption_text)
109
+ return image
110
+
111
class Head_Agent:
    """Top-level orchestrator: wires up the sub-agents and runs the Streamlit chat loop.

    Owns the OpenAI and Pinecone clients plus two indexes ("index-600" for
    ML content, "index-py-files" for Python code) and delegates to
    Classify_Agent / Query_Agent / Answering_Agent / Relevant_Documents_Agent.
    """

    def __init__(self, openai_key, pinecone_key) -> None:
        # TODO: Initialize the Head_Agent
        self.openai_key = openai_key
        self.pinecone_key = pinecone_key
        # Education level chosen in the UI; assigned in main_loop_1.
        self.selected_mode = ""

        self.openai_client = OpenAI(api_key=self.openai_key)
        self.pc = Pinecone(api_key=self.pinecone_key)
        self.pinecone_index = self.pc.Index("index-600")
        self.pinecone_index_python = self.pc.Index("index-py-files")

        self.setup_sub_agents()

    def setup_sub_agents(self):
        # TODO: Setup the sub-agents
        # One agent per responsibility: classify, retrieve, answer, relevance-check.
        self.classify_agent = Classify_Agent(self.openai_client)
        self.query_agent = Query_Agent(self.pinecone_index, self.pinecone_index_python, self.openai_client, None) # Pass embeddings if needed
        self.answering_agent = Answering_Agent(self.openai_client)
        self.relevant_documents_agent = Relevant_Documents_Agent(self.openai_client)

    def process_query_response(self, user_query, query_topic):
        """Retrieve topic-matched history and docs for *user_query*, then answer.

        query_topic is "ml" or "python"; "python" switches to the
        python-code Pinecone index.  Returns an apology string when the
        retrieved matches are judged irrelevant.
        """
        # Retrieve the history related to the query_topic
        conversation = []
        index = self.pinecone_index
        if query_topic == "ml":
            conversation = Head_Agent.get_history_about('ml')
        elif query_topic == 'python':
            conversation = Head_Agent.get_history_about('python')
            index = self.pinecone_index_python

        # get matches from Query_Agent, which uses Pinecone
        user_query_plus_conversation = f"The current query is: {user_query}"
        if len(conversation) > 0:
            conversation_text = "\n".join(conversation)
            user_query_plus_conversation += f'The current conversation is: {conversation_text}'

        # st.write(user_query_plus_conversation)
        matches_text = self.query_agent.query_vector_store(user_query_plus_conversation, index)

        if self.relevant_documents_agent.is_relevant(matches_text, user_query_plus_conversation):
            #maybe here we can ask GPT to make up an answer if there is no match
            response = self.answering_agent.generate_response(user_query, matches_text, conversation, self.selected_mode)
        else:
            response = "Sorry, I don't have relevant information to answer that query."

        return response

    @staticmethod
    def get_conversation():
        # ... (code for getting conversation history)
        # Unfiltered variant: last two user messages regardless of topic.
        return Head_Agent.get_history_about()

    @staticmethod
    def get_history_about(topic=None):
        """Return the last two stored message contents, optionally filtered by topic.

        With topic=None only user messages are collected; otherwise any
        message whose "topic" field matches.  Reads st.session_state.messages.
        """
        history = []

        for message in st.session_state.messages:
            role = message["role"]
            content = message["content"]

            if topic == None:
                if role == "user":
                    history.append(f"{content} ")
            else:
                if message["topic"] == topic:
                    history.append(f"{content} ")

        # st.write(f"user history in get_conversation is {history}")

        # NOTE(review): history is always a list here, so this check never
        # fails; the slice keeps only the two most recent entries.
        if history != None:
            history = history[-2:]

        return history

    @staticmethod
    def text_to_image(openai_client, text):
        """Generate a DALL-E 3 image for *text* and return it as a PIL Image."""
        response = openai_client.images.generate(
            model="dall-e-3",
            prompt = text,
            n=1,
            size="1024x1024"
        )
        image_url = response.data[0].url
        # Download the generated image and decode it in-memory.
        with urllib.request.urlopen(image_url) as image_url:
            img = Image.open(BytesIO(image_url.read()))

        return img

    def main_loop_1(self):
        # TODO: Run the main loop for the chatbot
        st.title("Mini Project 2: Streamlit Chatbot")

        # Check for existing session state variables
        if "openai_model" not in st.session_state:
            # ... (initialize model)
            # st.session_state.openai_model = openai_client #'GPT-3.5-turbo'
            st.session_state.openai_model = 'gpt-3.5-turbo'

        if "messages" not in st.session_state:
            # ... (initialize messages)
            st.session_state.messages = []

        # Define the selection options
        modes = ['1st grade student', 'middle school student', 'high school student', 'college student', 'grad student']

        # Use st.selectbox to let the user select a mode
        self.selected_mode = st.selectbox("Select your education level:", modes)

        # Display existing chat messages
        # ... (code for displaying messages)
        for message in st.session_state.messages:
            if message["role"] == "assistant":
                with st.chat_message("assistant"):
                    st.write(message["content"])
                    if message['image'] != None:
                        st.image(message['image'])
            else:
                with st.chat_message("user"):
                    st.write(message["content"])

        # Wait for user input
        if user_query := st.chat_input("What would you like to chat about?"):
            # # ... (append user message to messages)

            # ... (display user message)
            with st.chat_message("user"):
                st.write(user_query)

            # Generate AI response
            with st.chat_message("assistant"):
                # ... (send request to OpenAI API)
                response = ""
                topic = None
                image = None
                hasImage = False

                # Get the current conversation with new user query to check for users' intension
                conversation = self.get_conversation()
                user_query_plus_conversation = f"The current query is: {user_query}. The current conversation is: {conversation}"
                classify_query = self.classify_agent.classify_query(user_query_plus_conversation)

                # Dispatch on the numeric classification code.
                if classify_query == general_greeting_num:
                    response = "How can I assist you today?"
                elif classify_query == general_question_num:
                    response = "Please ask a question about Machine Learning or Python Code."
                elif classify_query == machine_learning_num:
                    # answering agent will 1. call query agent te get matches from pinecone, 2. verify the matches r relevant, 3. generate response
                    response = self.process_query_response(user_query, 'ml')

                    # answering agent will generate an image
                    if not contains_sorry(response):
                        image = self.answering_agent.generate_image(response)
                        hasImage = True
                    topic = "ml"

                elif classify_query == python_code_num:
                    response = self.process_query_response(user_query, 'python')
                    # answering agent will generate an image
                    if not contains_sorry(response):
                        image = self.answering_agent.generate_image(response)
                        hasImage = True
                    topic = "python"

                elif classify_query == obnoxious_num:
                    response = "Please dont be obnoxious."
                elif classify_query == default_num:
                    response = "I'm not sure how to respond to that."
                else:
                    response = "I'm not sure how to respond to that."

                # ... (get AI response and display it)
                st.write(response)
                if hasImage:
                    st.image(image)

            # Test moving append user_query down here:
            st.session_state.messages.append({"role": "user", "content": user_query, "topic": topic, "image": None})
            # ... (append AI response to messages)
            st.session_state.messages.append({"role": "assistant", "content": response, "topic": topic, "image": image})
291
+
292
+ if __name__ == "__main__":
293
+ head_agent = Head_Agent(OPENAI_KEY, pc_apikey)
294
+ head_agent.main_loop_1()