# semsearch/openai_utils.py
# Author: hanoch@raized.ai
# Commit: b505cc3 ("working version")
# Module logger; DEBUG level is forced so diagnostics emitted via
# report_error() below appear regardless of the root logger configuration.
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Standard-library helpers used by the polling / tool-call flow below.
import json
import time
import traceback
# Third-party clients: OpenAI SDK, HTTP, and Streamlit (session state).
import openai
import requests
import streamlit as st
# Project-local utilities (API credentials and the semantic-search index).
import utils
# Fixed sampling seed passed to chat.completions for reproducible outputs.
SEED = 42
def get_client():
    """Build an OpenAI client authenticated with the project's key and org ID."""
    return openai.OpenAI(
        api_key=utils.OPENAI_API_KEY,
        organization=utils.OPENAI_ORGANIZATION_ID,
    )
def getListOfCompanies(query, filters=None):
    """Search the company index for *query* and return a text blob of summaries.

    Runs the semantic search (storing the raw hits in
    ``st.session_state.db_search_results``) and formats the first 20 results
    that carry a ``Summary`` field into a newline-joined description string,
    suitable for feeding back to the assistant as a tool output.

    Args:
        query: Free-text search query.
        filters: Optional dict of search filters; only the ``country`` key is
            consulted here. Default is ``None`` (fixes the mutable-default
            pitfall of the original ``filters={}``).

    Returns:
        One string containing a "Description of company ..." paragraph per hit.
    """
    if filters is None:
        filters = {}
    # Fall back to the country selected in the UI session when the tool call
    # did not supply a country filter.
    country_filters = filters.get('country', st.session_state.country)
    st.session_state.db_search_results = utils.search_index(
        query,
        st.session_state.top_k,
        st.session_state.region,
        country_filters,
        st.session_state.retriever,
        st.session_state.index_namespace,
    )
    # Only the first 20 hits are rendered; entries without a 'Summary' field
    # are skipped.
    descriptions = "\n".join(
        f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n"
        for res in st.session_state.db_search_results[:20]
        if 'Summary' in res['data']
    )
    return descriptions
def report_error(txt):
    """Log an error message through the module logger.

    Fix: errors were previously logged at DEBUG level, which hides them under
    most logging configurations; they are now logged at ERROR. Lazy %-style
    args avoid formatting when the level is disabled.

    Args:
        txt: Human-readable description of the failure.
    """
    logger.error("\nError: \n%s", txt)
def wait_for_response(thread, run):
    """Poll an assistant run until it finishes, servicing tool calls on the way.

    Loops until the run completes, errors, or a 60-second timeout elapses:
      * 'completed'            -> fetch and return the thread's messages
      * 'queued'/'in_progress' -> sleep 1.5s and poll again
      * 'requires_action'      -> execute requested getListOfCompanies tool
                                  calls and submit their outputs to the run
      * anything else          -> report the status and stop
    On timeout or an error status, the thread's current messages are fetched
    and returned anyway so the caller always gets a messages object.

    Args:
        thread: OpenAI thread object the run belongs to.
        run: OpenAI run object to wait on.

    Returns:
        The OpenAI messages list object for the thread.
    """
    timeout = 60  # timeout in seconds
    started = time.time()
    # Fix: condition was `while True and time.time()-started<timeout` —
    # the redundant `True and` has been dropped; behavior is unchanged.
    while time.time() - started < timeout:
        # Retrieve the run status
        run_status = st.session_state.openai_client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id
        )
        print(f"Run status: {run_status.status}")
        # Check and print the step details
        run_steps = st.session_state.openai_client.beta.threads.runs.steps.list(
            thread_id=thread.id,
            run_id=run.id
        )
        for step in run_steps.data:
            if step.type == 'tool_calls':
                print(f"\n--------------------\nTool {step.type} invoked.\n--------------------\n")
            # NOTE(review): run steps appear to report type 'tool_calls' or
            # 'message_creation' only, so this branch looks unreachable —
            # confirm against the Assistants API before relying on it.
            if step.type == 'code_interpreter':
                print(f"Python Code Executed: {step.step_details['code_interpreter']['input']}")
        if run_status.status == 'completed':
            # Retrieve all messages from the thread
            messages = st.session_state.openai_client.beta.threads.messages.list(
                thread_id=thread.id
            )
            # Print all messages from the thread
            for msg in messages.data:
                role = msg.role
                content = msg.content[0].text.value
                print(f"{role.capitalize()}: {content}")
            return messages
        elif run_status.status in ['queued', 'in_progress']:
            print(f'{run_status.status.capitalize()}... Please wait.')
            time.sleep(1.5)  # Wait before checking again
        elif run_status.status == 'requires_action':
            required_action = run_status.required_action
            if required_action.type == 'submit_tool_outputs':
                print(f"Requires tool outputs: {required_action}")
                outputs = {}
                for tool_call in required_action.submit_tool_outputs.tool_calls:
                    if tool_call.function.name == "getListOfCompanies":
                        try:
                            args = json.loads(tool_call.function.arguments)
                            res = ''
                            if 'query' in args:
                                print(f"Processing tool_call {tool_call.id}. Calling 'getListOfCompanies with args: {args}")
                                # 'filters' arrives as a JSON-encoded string, not a dict.
                                search_filters = json.loads(args['filters']) if 'filters' in args else {}
                                res = getListOfCompanies(args['query'], search_filters)
                            # An empty output is still submitted so the run can proceed.
                            outputs[tool_call.id] = res
                        except Exception as e:
                            print(f"Error calling tools, {str(e)}")
                            traceback.print_exc()
                tool_outputs = [{"tool_call_id": k, "output": v} for (k, v) in outputs.items()]
                print(f"Finished tools calling: {str(tool_outputs)[:400]}")
                run = st.session_state.openai_client.beta.threads.runs.submit_tool_outputs(
                    thread_id=thread.id,
                    run_id=run.id,
                    tool_outputs=tool_outputs,
                )
                print(f"Required action {run_status.required_action}")
            else:
                report_error(f"Unknown required action type: {required_action}")
                break
        else:
            report_error(f"Unhandled Run status: {run_status.status}\n\nError: {run_status.last_error}\n")
            break
    if time.time() - started > timeout:
        report_error(f"Wait for response timeout after {timeout}")
    # Fall-through path: timed out or broke out on an error status; return
    # whatever messages the thread currently holds.
    report_error(f"Flow not completed")
    messages = st.session_state.openai_client.beta.threads.messages.list(
        thread_id=thread.id
    )
    return messages
def call_assistant(query, engine="gpt-3.5-turbo"):
    """Send *query* to the session's assistant thread and wait for its reply.

    Skips the call (returning the cached session messages) when the same
    query was just sent, to avoid duplicate submissions on Streamlit reruns.

    Args:
        query: User message text to post to the assistant thread.
        engine: Kept for signature compatibility; the assistant's own model is
            used, so this argument is not consulted here.

    Returns:
        The OpenAI messages list from wait_for_response, the cached
        ``st.session_state.messages`` when the query is a duplicate, or
        ``None`` if an exception occurred.
    """
    # Prevent re-sending the last message over and over
    print(f"Last query {st.session_state.last_user_query}, current query {query}")
    if st.session_state.last_user_query == query:
        report_error(f"That query '{query}' was just sent. We don't send the same query twice in a row. ")
        return st.session_state.messages
    try:
        thread = st.session_state.assistant_thread
        assistant_id = st.session_state.assistant_id
        # Post the user message (side effect only; return value not needed —
        # the original bound it to an unused local).
        st.session_state.openai_client.beta.threads.messages.create(
            thread.id,
            role="user",
            content=query,
        )
        run = st.session_state.openai_client.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=assistant_id,
        )
        messages = wait_for_response(thread, run)
        print(f"====================\nOpen AI response\n {str(messages)[:1000]}\n====================\n")
        return messages
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        # Fix: the original swallowed the exception with no stack trace,
        # making failures hard to diagnose.
        traceback.print_exc()
        return None
def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
    """Send a one-off chat completion using the session's message history.

    Args:
        prompt: User prompt appended to ``st.session_state.messages``.
        engine: Chat model name.
        temp: Sampling temperature.
        top_p: Nucleus-sampling parameter. Bug fix: this argument was accepted
            but never forwarded to the API; it is now passed through. Its
            default (1.0) matches the API default, so existing callers see no
            behavior change.
        max_tokens: Completion token limit.

    Returns:
        The stripped completion text, or a fallback string on failure.

    Raises:
        Exception: when the session is configured for the assistant flow,
            which must use call_assistant instead.
    """
    if st.session_state.report_type == "assistant":
        raise Exception("use call_assistant instead of call_openai")
    try:
        response = st.session_state.openai_client.chat.completions.create(
            model=engine,
            messages=st.session_state.messages + [{"role": "user", "content": prompt}],
            temperature=temp,
            top_p=top_p,
            seed=SEED,
            max_tokens=max_tokens,
        )
        print(f"====================\nOpen AI response\n {response}\n====================\n")
        text = response.choices[0].message.content.strip()
        return text
    except Exception as e:
        # Broad catch keeps the Streamlit UI alive; the caller receives a
        # sentinel string rather than an exception.
        print(f"An error occurred: {str(e)}")
        return "Failed to generate a response."
def send_message(role, content):
    """Append a message to the session's assistant thread.

    Fix: the API's return value was bound to an unused local ``message``; the
    call's side effect (adding the message to the thread) is all that matters.

    Args:
        role: Message role (e.g. 'user').
        content: Message text.
    """
    st.session_state.openai_client.beta.threads.messages.create(
        thread_id=st.session_state.assistant_thread.id,
        role=role,
        content=content
    )
def start_conversation():
    """Open a brand-new assistant thread and stash it in the Streamlit session."""
    client = st.session_state.openai_client
    st.session_state.assistant_thread = client.beta.threads.create()
def run_assistant():
    """Start a run of the session's assistant on the current thread and block
    until it leaves the 'queued'/'in_progress' states.

    Polls the run every 0.5 seconds after creation.

    Returns:
        The final run object as reported by the API.
    """
    client = st.session_state.openai_client
    thread_id = st.session_state.assistant_thread.id
    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=st.session_state.assistant.id,
    )
    while run.status in ("queued", "in_progress"):
        run = client.beta.threads.runs.retrieve(
            thread_id=thread_id,
            run_id=run.id,
        )
        time.sleep(0.5)
    return run