Spaces:
Sleeping
Sleeping
# %%writefile app.py
from setup_code import * # This imports everything from setup_code.py | |
# Integer codes for the query categories. default_num has no label in
# query_classes; it is the fallback code for unrecognized classifications.
(
    general_greeting_num,
    general_question_num,
    machine_learning_num,
    python_code_num,
    obnoxious_num,
    default_num,
) = range(6)

# Bracketed category label -> integer code. Insertion order matters because
# it fixes the order of the labels in query_classes_text.
query_classes = dict(
    [
        ('[General greeting]', general_greeting_num),
        ('[General question]', general_question_num),
        ('[Question about Machine Learning]', machine_learning_num),
        ('[Question about Python code]', python_code_num),
        ('[Obnoxious statement]', obnoxious_num),
    ]
)

# Comma-separated list of the labels, in dictionary order.
query_classes_text = ", ".join(label for label in query_classes)
class Classify_Agent:
    """Classify a query into one of the categories listed in ``query_classes``."""

    def __init__(self, openai_client) -> None:
        """Keep the client that is handed to ``get_completion``."""
        self.openai_client = openai_client

    def classify_query(self, query):
        """Return the integer category code for ``query``.

        The completion is asked to answer with one of the bracketed labels in
        ``query_classes_text``. The answer is looked up in ``query_classes``
        after stripping surrounding whitespace, so a reply such as
        ``"[General greeting]\\n"`` is still recognized.

        Returns:
            The matching code from ``query_classes``, or ``default_num`` when
            the completion is not a string (for example ``None``) or does not
            match any label.
        """
        prompt = f"Please classify this query in angle brackets <{query}> as one of the following in square brackets only: {query_classes_text}."
        classification_response = get_completion(self.openai_client, prompt)
        if not isinstance(classification_response, str):
            return default_num
        # A single lookup with a default replaces the membership test plus
        # two .get() calls of the earlier version.
        return query_classes.get(classification_response.strip(), default_num)
class Relevant_Documents_Agent:
    """Ask the completion model whether retrieved text is relevant to a query."""

    def __init__(self, openai_client) -> None:
        """Keep the client that is handed to ``get_completion``."""
        self.client = openai_client

    def get_relevance(self, conversation) -> str:
        """Placeholder: not implemented, so it currently returns ``None``."""

    def get_relevant_docs(self, conversation, docs) -> str:  # uses Query Agent to get relevant docs
        """Placeholder: not implemented, so it currently returns ``None``."""

    def is_relevant(self, matches_text, user_query_plus_conversation) -> bool:
        """Return ``is_Yes`` applied to the model's Yes/No relevance verdict.

        The prompt wraps ``matches_text`` in angle brackets and
        ``user_query_plus_conversation`` in double square brackets, then asks
        for a Yes or No answer.
        """
        relevance_prompt = (
            f"Please confirm that the text in angle brackets: <{matches_text}>, "
            f"is relevant to the text in double square brackets: "
            f"[[{user_query_plus_conversation}]]. Return Yes or No"
        )
        verdict = get_completion(self.client, relevance_prompt)
        return is_Yes(verdict)
class Query_Agent:
    """Embed a query with OpenAI and fetch the closest matches from Pinecone."""

    def __init__(self, pinecone_index, pinecone_index_python, openai_client, embeddings) -> None:
        """Store the two Pinecone indexes, the OpenAI client and ``embeddings``.

        ``embeddings`` is only stored; no method of this class reads it.
        """
        self.pinecone_index = pinecone_index
        self.pinecone_index_python = pinecone_index_python
        self.openai_client = openai_client
        self.embeddings = embeddings

    def get_openai_embedding(self, text, model="text-embedding-ada-002"):
        """Return the embedding of ``text`` with newlines replaced by spaces."""
        text = text.replace("\n", " ")
        return self.openai_client.embeddings.create(input=[text], model=model).data[0].embedding

    @staticmethod
    def _first_namespace(index):
        """Return the first namespace reported by ``index``, or ``""`` if none.

        Indexing the key list with ``[0]`` raised ``IndexError`` when the
        index reported no namespaces.
        NOTE(review): ``""`` is assumed to select Pinecone's default
        namespace -- confirm against the client version in use.
        """
        namespaces = index.describe_index_stats()['namespaces']
        return next(iter(namespaces), "")

    def query_vector_store(self, query, index=None, k=5) -> str:
        """Return ``get_top_k_text`` of the ``k`` nearest matches for ``query``.

        Args:
            query: Text to embed and search for.
            index: Pinecone index to search; defaults to ``self.pinecone_index``.
            k: Number of matches requested from the index.
        """
        if index is None:
            index = self.pinecone_index
        query_embedding = self.get_openai_embedding(query)
        ns = self._first_namespace(index)
        matches = index.query(
            namespace=ns,
            top_k=k,
            vector=query_embedding,
            include_values=True,
            include_metadata=True,
        )
        return get_top_k_text(matches)
class Answering_Agent:
    """Produce the text answer and the illustrative image for a query."""

    def __init__(self, openai_client) -> None:
        """Keep the client used for completions and image generation."""
        self.client = openai_client

    def generate_response(self, query, docs, conv_history, selected_mode):
        """Return a completion that answers ``query`` from ``docs``.

        The prompt also carries ``conv_history`` as context and asks for
        language appropriate for ``selected_mode``.
        """
        request = (
            f"Based on this text in angle brackets: <{docs}>, "
            f"please summarize a response to this query: {query} "
            f"in the context of this conversation: {conv_history}. "
            f"Please use language appropriate for a {selected_mode}."
        )
        return get_completion(self.client, request)

    def generate_image(self, text):
        """Return the image ``Head_Agent.text_to_image`` builds for ``text``.

        A completion first turns ``text`` into a short caption; that caption
        is what gets passed on as the image prompt.
        """
        caption = get_completion(
            self.client,
            f"Based on this text, repeated here in double square brackets for your reference: [[{text}]], "
            "please generate a simple caption that I can use with dall-e to generate an instructional image.",
        )
        return Head_Agent.text_to_image(self.client, caption)
class Head_Agent:
    """Top-level chatbot: owns the API clients, the sub-agents and the UI loop."""

    def __init__(self, openai_key, pinecone_key) -> None:
        """Create the OpenAI and Pinecone clients, open both indexes and
        build the sub-agents.

        Args:
            openai_key: API key passed to ``OpenAI``.
            pinecone_key: API key passed to ``Pinecone``.
        """
        self.openai_key = openai_key
        self.pinecone_key = pinecone_key
        self.selected_mode = ""
        self.openai_client = OpenAI(api_key=self.openai_key)
        self.pc = Pinecone(api_key=self.pinecone_key)
        self.pinecone_index = self.pc.Index("index-600")
        self.pinecone_index_python = self.pc.Index("index-py-files")
        self.setup_sub_agents()

    def setup_sub_agents(self):
        """Instantiate the classify, query, answering and relevance agents."""
        self.classify_agent = Classify_Agent(self.openai_client)
        # No embeddings object is supplied; Query_Agent receives None for it.
        self.query_agent = Query_Agent(self.pinecone_index, self.pinecone_index_python, self.openai_client, None)
        self.answering_agent = Answering_Agent(self.openai_client)
        self.relevant_documents_agent = Relevant_Documents_Agent(self.openai_client)

    def process_query_response(self, user_query, query_topic):
        """Answer ``user_query`` from the vector store for ``query_topic``.

        Args:
            user_query: The text the user just typed.
            query_topic: ``"ml"`` or ``"python"``. ``"python"`` searches
                ``self.pinecone_index_python``; anything else searches
                ``self.pinecone_index``. Other values use no history.

        Returns:
            The generated answer, or a fixed apology string when the
            relevance agent rejects the retrieved text.
        """
        conversation = []
        index = self.pinecone_index
        if query_topic == "ml":
            conversation = Head_Agent.get_history_about('ml')
        elif query_topic == 'python':
            conversation = Head_Agent.get_history_about('python')
            index = self.pinecone_index_python

        user_query_plus_conversation = f"The current query is: {user_query}"
        if conversation:
            conversation_text = "\n".join(conversation)
            # The ". " separator keeps the query from running straight into
            # the next sentence; it matches the wording used in main_loop_1.
            user_query_plus_conversation += f'. The current conversation is: {conversation_text}'

        matches_text = self.query_agent.query_vector_store(user_query_plus_conversation, index)
        if self.relevant_documents_agent.is_relevant(matches_text, user_query_plus_conversation):
            return self.answering_agent.generate_response(
                user_query, matches_text, conversation, self.selected_mode
            )
        return "Sorry, I don't have relevant information to answer that query."

    @staticmethod
    def get_conversation():
        """Return the last two user messages of the session.

        Declared static because it takes no ``self`` yet is called as
        ``self.get_conversation()``; without the decorator that call raises
        ``TypeError``.
        """
        return Head_Agent.get_history_about()

    @staticmethod
    def _select_history(messages, topic=None):
        """Return the last two matching message contents, each followed by a space.

        With ``topic`` None only messages whose role is ``"user"`` match;
        otherwise every message whose ``"topic"`` equals ``topic`` matches,
        whatever its role.
        """
        if topic is None:
            selected = [f"{m['content']} " for m in messages if m["role"] == "user"]
        else:
            selected = [f"{m['content']} " for m in messages if m.get("topic") == topic]
        return selected[-2:]

    @staticmethod
    def get_history_about(topic=None):
        """Return recent history from ``st.session_state.messages``.

        See ``_select_history`` for the selection rule. Static so that it
        works both as ``Head_Agent.get_history_about(...)`` and through an
        instance.
        """
        return Head_Agent._select_history(st.session_state.messages, topic)

    @staticmethod
    def text_to_image(openai_client, text):
        """Generate one 1024x1024 ``dall-e-3`` image for ``text`` and return it.

        The image is downloaded from the URL in the API response and opened
        with ``Image.open`` from an in-memory buffer.
        """
        response = openai_client.images.generate(
            model="dall-e-3",
            prompt=text,
            n=1,
            size="1024x1024",
        )
        image_url = response.data[0].url
        with urllib.request.urlopen(image_url) as http_response:
            img = Image.open(BytesIO(http_response.read()))
        return img

    def _answer_with_image(self, user_query, topic):
        """Return ``(response, topic, image)`` for a topic-specific question.

        The image is generated only when ``contains_sorry`` reports that the
        response is not the apology text; otherwise it stays ``None``.
        """
        response = self.process_query_response(user_query, topic)
        image = None
        if not contains_sorry(response):
            image = self.answering_agent.generate_image(response)
        return response, topic, image

    def _respond(self, user_query, query_class):
        """Map a class code to ``(response, topic, image)``."""
        if query_class == general_greeting_num:
            return "How can I assist you today?", None, None
        if query_class == general_question_num:
            return "Please ask a question about Machine Learning or Python Code.", None, None
        if query_class == machine_learning_num:
            return self._answer_with_image(user_query, "ml")
        if query_class == python_code_num:
            return self._answer_with_image(user_query, "python")
        if query_class == obnoxious_num:
            return "Please dont be obnoxious.", None, None
        # default_num and any unexpected code share the same reply.
        return "I'm not sure how to respond to that.", None, None

    def main_loop_1(self):
        """Render the Streamlit chat page and handle at most one new query."""
        st.title("Mini Project 2: Streamlit Chatbot")

        # Initialize session state on first use.
        if "openai_model" not in st.session_state:
            st.session_state.openai_model = 'gpt-3.5-turbo'
        if "messages" not in st.session_state:
            st.session_state.messages = []

        modes = ['1st grade student', 'middle school student', 'high school student', 'college student', 'grad student']
        self.selected_mode = st.selectbox("Select your education level:", modes)

        # Replay the stored conversation.
        for message in st.session_state.messages:
            if message["role"] == "assistant":
                with st.chat_message("assistant"):
                    st.write(message["content"])
                    if message['image'] is not None:
                        st.image(message['image'])
            else:
                with st.chat_message("user"):
                    st.write(message["content"])

        if user_query := st.chat_input("What would you like to chat about?"):
            with st.chat_message("user"):
                st.write(user_query)

            with st.chat_message("assistant"):
                # Classify the new query together with the recent user
                # messages so the intent is judged in context.
                conversation = self.get_conversation()
                user_query_plus_conversation = f"The current query is: {user_query}. The current conversation is: {conversation}"
                query_class = self.classify_agent.classify_query(user_query_plus_conversation)
                response, topic, image = self._respond(user_query, query_class)
                st.write(response)
                if image is not None:
                    st.image(image)

            # The user message is stored only after the response is built,
            # so the history used above does not include the current query.
            st.session_state.messages.append({"role": "user", "content": user_query, "topic": topic, "image": None})
            st.session_state.messages.append({"role": "assistant", "content": response, "topic": topic, "image": image})
# Script entry point: build the head agent and render the chat page.
# OPENAI_KEY and pc_apikey are not defined in this file; presumably they come
# from the setup_code star import -- TODO confirm.
if __name__ == "__main__":
    head_agent = Head_Agent(OPENAI_KEY, pc_apikey)
    head_agent.main_loop_1()