File size: 8,002 Bytes
b505cc3
 
 
 
5c9ea55
e54b3e0
5c9ea55
 
d54eee9
 
e54b3e0
 
da0018b
d54eee9
 
 
 
 
749a763
d54eee9
5c9ea55
437c715
9376584
 
5c9ea55
 
7c5594b
b505cc3
7c5594b
5c9ea55
7c5594b
 
 
5c9ea55
 
 
 
 
7c5594b
5c9ea55
 
 
 
 
 
437c715
5c9ea55
437c715
5c9ea55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd8ca06
 
ba1f3e2
fd8ca06
 
 
 
 
749a763
 
 
528a449
749a763
 
 
 
fd8ca06
 
 
 
 
7c3b5b3
fd8ca06
 
 
 
 
7c5594b
fd8ca06
 
7c5594b
 
5c9ea55
7c5594b
 
 
 
 
 
 
 
 
 
5c9ea55
0c14e18
ba1f3e2
 
 
 
 
0c14e18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c9ea55
aac3522
d54eee9
0c14e18
d54eee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aac3522
 
 
e54b3e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

import json
import time
import traceback

import openai
import requests
import streamlit as st

import utils


SEED = 42     

def get_client():
    """Build an OpenAI client from the project's API key and organization id."""
    return openai.OpenAI(
        api_key=utils.OPENAI_API_KEY,
        organization=utils.OPENAI_ORGANIZATION_ID,
    )

def getListOfCompanies(query, filters=None):
    """Search the company index for *query* and return formatted summaries.

    Side effect: stores the raw search results in
    ``st.session_state.db_search_results``.

    Args:
        query: Free-text search query passed to ``utils.search_index``.
        filters: Optional dict of filters; a ``'country'`` key overrides the
            country filter taken from ``st.session_state.country``.

    Returns:
        A newline-joined string with one "Description of company ..." entry
        per result (first 20 only) that carries a ``'Summary'`` field.
    """
    # Fix: the original default was a mutable `{}`, which is shared across
    # calls; use None as the sentinel instead.
    if filters is None:
        filters = {}
    country_filters = filters.get('country', st.session_state.country)
    st.session_state.db_search_results = utils.search_index(
        query,
        st.session_state.top_k,
        st.session_state.region,
        country_filters,
        st.session_state.retriever,
        st.session_state.index_namespace,
    )
    descriptions = "\n".join(
        f"Description of company \"{res['name']}\":  {res['data']['Summary']}.\n"
        for res in st.session_state.db_search_results[:20]
        if 'Summary' in res['data']
    )
    return descriptions

def report_error(txt):
    """Emit *txt* to the module logger at DEBUG level with an error banner."""
    message = f"\nError: \n{txt}"
    logger.debug(message)

def wait_for_response(thread, run):
    """Poll an OpenAI Assistants run until it finishes, servicing tool calls.

    Loops for at most 60 seconds, retrieving the run status each pass:
      * 'completed'              -> fetch and return all thread messages.
      * 'queued'/'in_progress'   -> sleep 1.5 s and poll again.
      * 'requires_action'        -> execute requested 'getListOfCompanies'
                                    tool calls and submit the outputs back.
      * anything else            -> log via report_error and stop.

    Args:
        thread: OpenAI thread object the run belongs to.
        run: OpenAI run object to wait on.

    Returns:
        The thread's messages list — complete on success, or whatever is
        present so far if the loop timed out or hit an unhandled status.
    """
    timeout = 60    #timeout in seconds
    started = time.time()
    # NOTE(review): `True and` is redundant — the loop is bounded solely by
    # the timeout condition.
    while True and time.time()-started<timeout:
        # Retrieve the run status
        run_status = st.session_state.openai_client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id
        )
        print(f"Run status: {run_status.status}")
        # Check and print the step details
        run_steps = st.session_state.openai_client.beta.threads.runs.steps.list(
            thread_id=thread.id,
            run_id=run.id
        )
        for step in run_steps.data:
            #print(step)
            if step.type == 'tool_calls':
                print(f"\n--------------------\nTool {step.type} invoked.\n--------------------\n")

            # If step involves code execution, print the code
            # NOTE(review): run-step `type` values are normally 'tool_calls'
            # or 'message_creation'; 'code_interpreter' looks like a tool-call
            # *subtype*, so this branch may never fire — confirm against the
            # Assistants API reference.
            if step.type == 'code_interpreter':
                print(f"Python Code Executed: {step.step_details['code_interpreter']['input']}")

        if run_status.status == 'completed':
            # Retrieve all messages from the thread
            messages = st.session_state.openai_client.beta.threads.messages.list(
                thread_id=thread.id
            )

            # Print all messages from the thread
            for msg in messages.data:
                role = msg.role
                content = msg.content[0].text.value
                print(f"{role.capitalize()}: {content}")
            return messages
        elif run_status.status in ['queued', 'in_progress']:
            print(f'{run_status.status.capitalize()}... Please wait.')
            time.sleep(1.5)  # Wait before checking again
        elif run_status.status == 'requires_action':
            required_action = run_status.required_action
            if required_action.type == 'submit_tool_outputs':
                print(f"Requires tool outputs: {required_action}")
                outputs = {}
                for tool_call in required_action.submit_tool_outputs.tool_calls:
                    # Only the 'getListOfCompanies' tool is recognized; other
                    # tool calls are silently left without an output entry.
                    if tool_call.function.name =="getListOfCompanies":
                        try:
                            args = json.loads(tool_call.function.arguments)
                            res = ''
                            if 'query' in args:
                                print(f"Processing tool_call {tool_call.id}. Calling 'getListOfCompanies with args: {args}" )
                                # 'filters' arrives as a JSON-encoded string
                                # inside the arguments payload.
                                search_filters = json.loads(args['filters']) if 'filters' in args else {}
                                res = getListOfCompanies(args['query'], search_filters)


                            outputs[tool_call.id] = res
                        except Exception as e:
                            # Best-effort: a failed tool call is logged and
                            # skipped rather than aborting the whole run.
                            print(f"Error calling tools, {str(e)}")
                            traceback.print_exc()

                tool_outputs=[{"tool_call_id": k, "output": v} for (k,v) in outputs.items()]
                print(f"Finished tools calling: {str(tool_outputs)[:400]}")
                # Submit the collected outputs; the next loop pass re-polls
                # the (now resumed) run.
                run = st.session_state.openai_client.beta.threads.runs.submit_tool_outputs(
                    thread_id=thread.id,
                    run_id=run.id,
                    tool_outputs=tool_outputs,
                )
                print(f"Required action {run_status.required_action}")
                #return run_status
            else:
                report_error(f"Unknown required action type: {required_action}")
                break
        else:
            report_error(f"Unhandled Run status: {run_status.status}\n\nError: {run_status.last_error}\n")
            break
    # Fallthrough: reached only when the run did NOT complete (timeout or an
    # unhandled status above); return whatever messages exist so far.
    if time.time()-started>timeout:
        report_error(f"Wait for response timeout after {timeout}")
    report_error(f"Flow not completed")
    messages = st.session_state.openai_client.beta.threads.messages.list(
        thread_id=thread.id
    )
    return messages
            

def call_assistant(query, engine="gpt-3.5-turbo"):
    """Send *query* to the session's assistant thread and wait for the reply.

    Args:
        query: User message text to append to the thread.
        engine: Unused here (the assistant's model is configured on the
            assistant itself); kept for signature parity with ``call_openai``.

    Returns:
        The thread's message list on success, ``st.session_state.messages``
        when *query* duplicates the previous query, or ``None`` on error.
    """
    # Prevent re-sending the last message over and over.
    print(f"Last query {st.session_state.last_user_query}, current query {query}")
    if st.session_state.last_user_query == query:
        report_error(f"That query '{query}' was just sent. We don't send the same query twice in a row. ")
        return st.session_state.messages
    try:
        thread = st.session_state.assistant_thread
        assistant_id = st.session_state.assistant_id
        # Append the user message to the thread, then start a run on it.
        # (The created message object itself is not needed afterwards.)
        st.session_state.openai_client.beta.threads.messages.create(
            thread.id,
            role="user",
            content=query,
        )
        run = st.session_state.openai_client.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=assistant_id,
        )
        messages = wait_for_response(thread, run)

        print(f"====================\nOpen AI response\n {str(messages)[:1000]}\n====================\n")
        return messages
    except Exception as e:
        # Broad catch so a transient API failure doesn't crash the UI;
        # log the traceback and return None explicitly.
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()
        return None
        

def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
    """Run a one-shot chat completion of *prompt* over the session history.

    Args:
        prompt: User message appended to ``st.session_state.messages``.
        engine: Chat model name.
        temp: Sampling temperature.
        top_p: Accepted for signature compatibility; not forwarded.
        max_tokens: Completion token budget.

    Returns:
        The stripped completion text, or a fixed failure string on error.

    Raises:
        Exception: if the session is configured for the assistant flow.
    """
    if st.session_state.report_type=="assistant":
        raise Exception("use call_assistant instead of call_openai")
    try:
        completion = st.session_state.openai_client.chat.completions.create(
            model=engine,
            messages=st.session_state.messages + [{"role": "user", "content": prompt}],
            temperature=temp,
            seed=SEED,
            max_tokens=max_tokens,
        )
        print(f"====================\nOpen AI response\n {completion}\n====================\n")
        return completion.choices[0].message.content.strip()
    except Exception as e:
        print(f"An error occurred: {str(e)}")
    return "Failed to generate a response."


def send_message(role, content):
    """Append a message to the session's assistant thread.

    Args:
        role: Message role (e.g. "user").
        content: Message text.

    Returns:
        The created thread message object (previously bound to an unused
        local and discarded; returning it is backward compatible).
    """
    return st.session_state.openai_client.beta.threads.messages.create(
        thread_id=st.session_state.assistant_thread.id,
        role=role,
        content=content
    )

def start_conversation():
    """Begin a fresh assistant conversation by creating a new thread."""
    client = st.session_state.openai_client
    st.session_state.assistant_thread = client.beta.threads.create()
    

def run_assistant():
    """Start a run on the assistant thread and poll until it leaves the queue.

    Returns:
        The final run object, whose status is no longer 'queued' or
        'in_progress'.
    """
    session = st.session_state
    run = session.openai_client.beta.threads.runs.create(
        thread_id=session.assistant_thread.id,
        assistant_id=session.assistant.id,
    )
    # Re-fetch the run every 0.5 s while it is still pending.
    while run.status in ("queued", "in_progress"):
        run = session.openai_client.beta.threads.runs.retrieve(
            thread_id=session.assistant_thread.id,
            run_id=run.id,
        )
        time.sleep(0.5)
    return run