Spaces:
Running
Running
import json
import logging
import time
import traceback

import openai
import requests
import streamlit as st

import utils

# Module-level logger. The level is DEBUG so that nothing logged through it
# is filtered out at the logger itself.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Fixed seed passed to chat-completion requests (see call_openai).
SEED = 42
def get_client():
    """Build an OpenAI client from the credentials held in ``utils``.

    Returns:
        A new ``openai.OpenAI`` instance created with
        ``utils.OPENAI_API_KEY`` and ``utils.OPENAI_ORGANIZATION_ID``.
    """
    credentials = {
        "api_key": utils.OPENAI_API_KEY,
        "organization": utils.OPENAI_ORGANIZATION_ID,
    }
    return openai.OpenAI(**credentials)
def getListOfCompanies(query, filters=None):
    """Search the company index and return a text digest of the top hits.

    As a side effect the raw search results are stored in
    ``st.session_state.db_search_results``.

    Args:
        query: Search text forwarded to ``utils.search_index``.
        filters: Optional mapping. When it contains a ``'country'`` key, that
            value is used as the country filter instead of
            ``st.session_state.country``. ``None`` (the default) or an empty
            mapping means "no overrides".

    Returns:
        A string with one ``Description of company "<name>": <summary>.``
        entry for each of the first 20 results whose ``data`` contains a
        ``'Summary'`` key.
    """
    # Use a fresh dict per call instead of a shared mutable default argument.
    if filters is None:
        filters = {}

    state = st.session_state
    if 'country' in filters:
        country_filters = filters['country']
    else:
        country_filters = state.country

    state.db_search_results = utils.search_index(
        query,
        state.top_k,
        state.region,
        country_filters,
        state.retriever,
        state.index_namespace,
    )

    # Each entry already ends in a newline, so joining with "\n" leaves a
    # blank line between consecutive descriptions (same as before).
    descriptions = "\n".join(
        f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n"
        for res in state.db_search_results[:20]
        if 'Summary' in res['data']
    )
    return descriptions
def report_error(txt):
    """Log *txt* as an error on the module logger.

    Args:
        txt: Human-readable description of the problem.
    """
    # Logged at ERROR rather than DEBUG: when the application has not
    # configured any logging handlers, Python's last-resort handler only emits
    # WARNING and above, so DEBUG records would hide every reported error.
    logger.error("\nError: \n%s", txt)
def _first_text_of(msg):
    """Return the first text part of a thread message, or ``''`` if it has none."""
    for part in (msg.content or []):
        text = getattr(part, "text", None)
        if text is not None:
            return text.value
    return ""


def _print_run_steps(client, thread_id, run_id):
    """Print a trace line for every tool-call step recorded on the run so far."""
    run_steps = client.beta.threads.runs.steps.list(
        thread_id=thread_id,
        run_id=run_id,
    )
    for step in run_steps.data:
        if step.type != 'tool_calls':
            continue
        print(f"\n--------------------\nTool {step.type} invoked.\n--------------------\n")
        # Code-interpreter invocations are individual tool calls inside the
        # step details, not a step type of their own.
        for call in getattr(step.step_details, "tool_calls", None) or []:
            if getattr(call, "type", None) == 'code_interpreter':
                print(f"Python Code Executed: {call.code_interpreter.input}")


def _run_tool_call(tool_call):
    """Execute one requested function call and return its output as a string.

    Never raises: failures are printed and turned into an error string so the
    caller can still submit an output for every tool call of the run.
    """
    name = tool_call.function.name
    if name != "getListOfCompanies":
        return f"Error: unsupported tool '{name}'"
    try:
        args = json.loads(tool_call.function.arguments)
        if 'query' not in args:
            return ''
        print(f"Processing tool_call {tool_call.id}. Calling 'getListOfCompanies with args: {args}")
        # 'filters' may arrive either as a JSON-encoded string or as an
        # already-decoded object, depending on the tool schema.
        search_filters = args.get('filters') or {}
        if isinstance(search_filters, str):
            search_filters = json.loads(search_filters) or {}
        return getListOfCompanies(args['query'], search_filters)
    except Exception as e:
        print(f"Error calling tools, {str(e)}")
        traceback.print_exc()
        return f"Error calling tool '{name}': {e}"


def wait_for_response(thread, run, timeout=60, poll_interval=1.5):
    """Poll an assistant run until it completes, fails, or times out.

    While polling, pending ``submit_tool_outputs`` actions are served by
    calling ``getListOfCompanies`` and submitting one output per tool call.

    Args:
        thread: Thread object; only its ``id`` is used.
        run: Run object; only its ``id`` is used.
        timeout: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep while the run is queued/in progress.

    Returns:
        The result of listing the thread's messages. It is returned both when
        the run completes and when polling stops early (timeout, unhandled
        status, unknown required action); early stops are reported through
        ``report_error`` first.
    """
    client = st.session_state.openai_client
    # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
    deadline = time.monotonic() + timeout

    while time.monotonic() < deadline:
        run_status = client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id,
        )
        status = run_status.status
        print(f"Run status: {status}")
        _print_run_steps(client, thread.id, run.id)

        if status == 'completed':
            messages = client.beta.threads.messages.list(thread_id=thread.id)
            for msg in messages.data:
                print(f"{msg.role.capitalize()}: {_first_text_of(msg)}")
            return messages

        if status in ('queued', 'in_progress'):
            print(f'{status.capitalize()}... Please wait.')
            time.sleep(poll_interval)  # Wait before checking again
            continue

        if status != 'requires_action':
            report_error(f"Unhandled Run status: {status}\n\nError: {run_status.last_error}\n")
            break

        required_action = run_status.required_action
        if required_action.type != 'submit_tool_outputs':
            report_error(f"Unknown required action type: {required_action}")
            break

        print(f"Requires tool outputs: {required_action}")
        # Every tool call gets an output (possibly an error string) so the
        # submission is complete and the run can proceed.
        tool_outputs = [
            {"tool_call_id": tool_call.id, "output": _run_tool_call(tool_call)}
            for tool_call in required_action.submit_tool_outputs.tool_calls
        ]
        print(f"Finished tools calling: {str(tool_outputs)[:400]}")
        run = client.beta.threads.runs.submit_tool_outputs(
            thread_id=thread.id,
            run_id=run.id,
            tool_outputs=tool_outputs,
        )
        print(f"Required action {run_status.required_action}")
    else:
        # Loop condition became false without a break: the deadline passed.
        report_error(f"Wait for response timeout after {timeout}")

    report_error("Flow not completed")
    return client.beta.threads.messages.list(thread_id=thread.id)
def call_assistant(query, engine="gpt-3.5-turbo"):
    """Post *query* to the assistant thread, start a run, and wait for the reply.

    Args:
        query: User text added to ``st.session_state.assistant_thread``.
        engine: Not used by this function; kept so existing callers that pass
            it keep working.

    Returns:
        Whatever ``wait_for_response`` returns on success;
        ``st.session_state.messages`` when *query* equals
        ``st.session_state.last_user_query`` (nothing is sent in that case);
        ``None`` if any step raised an exception.
    """
    # Prevent re-sending the last message over and over.
    print(f"Last query {st.session_state.last_user_query}, current query {query}")
    if st.session_state.last_user_query == query:
        report_error(f"That query '{query}' was just sent. We don't send the same query twice in a row. ")
        return st.session_state.messages

    try:
        client = st.session_state.openai_client
        thread = st.session_state.assistant_thread
        assistant_id = st.session_state.assistant_id
        client.beta.threads.messages.create(
            thread.id,
            role="user",
            content=query,
        )
        run = client.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=assistant_id,
        )
        messages = wait_for_response(thread, run)
        print(f"====================\nOpen AI response\n {str(messages)[:1000]}\n====================\n")
        return messages
    except Exception as e:
        # Best-effort boundary: report the failure with its traceback and
        # signal it to the caller with an explicit None.
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()
        return None
def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
    """Send *prompt* plus the stored chat history to the chat-completions API.

    Args:
        prompt: User text appended (for this request only) to
            ``st.session_state.messages``.
        engine: Model name passed as ``model``.
        temp: Sampling temperature.
        top_p: Nucleus-sampling parameter, forwarded to the API.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The stripped text of the first choice, or the string
        ``"Failed to generate a response."`` if the request or the response
        handling raised.

    Raises:
        Exception: If ``st.session_state.report_type`` is ``"assistant"``;
            ``call_assistant`` must be used in that mode.
    """
    if st.session_state.report_type == "assistant":
        raise Exception("use call_assistant instead of call_openai")

    try:
        response = st.session_state.openai_client.chat.completions.create(
            model=engine,
            messages=st.session_state.messages + [{"role": "user", "content": prompt}],
            temperature=temp,
            top_p=top_p,  # previously accepted by the signature but never sent
            seed=SEED,
            max_tokens=max_tokens,
        )
        print(f"====================\nOpen AI response\n {response}\n====================\n")
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return "Failed to generate a response."
def send_message(role, content):
    """Add a message to the thread in ``st.session_state.assistant_thread``.

    Args:
        role: Role to attach to the message.
        content: Message body.

    Returns:
        None. The created message object is not returned.
    """
    threads_api = st.session_state.openai_client.beta.threads
    threads_api.messages.create(
        thread_id=st.session_state.assistant_thread.id,
        role=role,
        content=content,
    )
def start_conversation():
    """Create a new thread and store it in ``st.session_state.assistant_thread``."""
    new_thread = st.session_state.openai_client.beta.threads.create()
    st.session_state.assistant_thread = new_thread
def run_assistant():
    """Start a run on the current thread and poll until it stops waiting.

    Returns:
        The run object as last retrieved, i.e. the first one whose status is
        neither ``"queued"`` nor ``"in_progress"``. Other statuses are not
        interpreted here and are left to the caller.
    """
    # NOTE(review): this reads st.session_state.assistant.id, whereas
    # call_assistant reads st.session_state.assistant_id - confirm both are set.
    client = st.session_state.openai_client
    thread_id = st.session_state.assistant_thread.id
    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=st.session_state.assistant.id,
    )
    while run.status in ("queued", "in_progress"):
        # Sleep before polling, not after, so the function returns as soon as
        # a non-waiting status is observed.
        time.sleep(0.5)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread_id,
            run_id=run.id,
        )
    return run