import logging
# Module-level logger. Its own level is DEBUG so records of every level pass this
# logger's filter; whether they are actually emitted depends on the handlers configured.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

import json
import time
import traceback

import openai
import requests
import streamlit as st

import utils


# Fixed sampling seed forwarded to chat.completions.create (see call_openai) so that
# repeated prompts give results that are as reproducible as the API allows.
SEED: int = 42

def get_client():
    """Build an OpenAI client from the API key and organization id held in ``utils``."""
    api_key = utils.OPENAI_API_KEY
    organization = utils.OPENAI_ORGANIZATION_ID
    return openai.OpenAI(api_key=api_key, organization=organization)

def getListOfCompanies(query, filters=None):
    """Search the company index and return company summaries as one text blob.

    This is also the local implementation of the assistant tool of the same name
    (dispatched from ``wait_for_response``), so its name is kept as-is.

    Args:
        query: Free-text search query forwarded to ``utils.search_index``.
        filters: Optional dict of search filters. Only the ``'country'`` key is
            read; when it is absent, ``st.session_state.country`` is used instead.
            Defaults to ``None`` (treated as no filters) rather than a shared
            mutable ``{}``.

    Returns:
        A newline-joined string with one ``Description of company "<name>": ...``
        entry for each of the first 20 results whose ``data`` has a ``'Summary'``.

    Side effects:
        Stores the raw search results in ``st.session_state.db_search_results``.
    """
    if filters is None:
        filters = {}
    # Evaluated lazily so session_state.country is only read when needed.
    country_filters = filters['country'] if 'country' in filters else st.session_state.country
    st.session_state.db_search_results = utils.search_index(
        query,
        st.session_state.top_k,
        st.session_state.region,
        country_filters,
        st.session_state.retriever,
        st.session_state.index_namespace,
    )
    descriptions = "\n".join(
        f"Description of company \"{res['name']}\":  {res['data']['Summary']}.\n"
        for res in st.session_state.db_search_results[:20]
        if 'Summary' in res['data']
    )
    return descriptions

def report_error(txt):
    """Log *txt* as an error on this module's logger.

    ERROR level is used because DEBUG records are dropped by Python's default
    last-resort handler (level WARNING) when no logging configuration is
    installed, which made these reports invisible.
    """
    logger.error("\nError: \n%s", txt)

def _log_run_steps(client, thread, run):
    """Print the tool-call steps recorded so far for *run* (debug output only)."""
    run_steps = client.beta.threads.runs.steps.list(
        thread_id=thread.id,
        run_id=run.id
    )
    for step in run_steps.data:
        if step.type != 'tool_calls':
            continue
        print(f"\n--------------------\nTool {step.type} invoked.\n--------------------\n")
        # Code-interpreter calls live inside the step's tool_calls list, not in
        # the step type itself.
        for call in getattr(step.step_details, 'tool_calls', None) or []:
            if call.type == 'code_interpreter':
                print(f"Python Code Executed: {call.code_interpreter.input}")


def _print_messages(messages):
    """Print the role and first text part of every message in a thread listing."""
    for msg in messages.data:
        # Content parts may be non-text (e.g. images); skip those instead of crashing.
        text = next((part.text.value for part in msg.content if part.type == 'text'), '')
        print(f"{msg.role.capitalize()}: {text}")


def _execute_tool_call(tool_call):
    """Run one assistant function call locally and return its output string.

    An output string is always returned, including on failure. The API expects an
    output for every requested tool call, so dropping one makes
    submit_tool_outputs fail for the whole batch.
    """
    name = tool_call.function.name
    if name != "getListOfCompanies":
        return f"Error: unknown function '{name}'"
    try:
        args = json.loads(tool_call.function.arguments)
        if 'query' not in args:
            return ''
        print(f"Processing tool_call {tool_call.id}. Calling 'getListOfCompanies with args: {args}")
        # 'filters' may arrive as a JSON-encoded string or as an already-decoded object.
        search_filters = args.get('filters') or {}
        if isinstance(search_filters, str):
            search_filters = json.loads(search_filters)
        return getListOfCompanies(args['query'], search_filters)
    except Exception as e:
        print(f"Error calling tools, {str(e)}")
        traceback.print_exc()
        return f"Error calling tools, {str(e)}"


def _cancel_run(client, thread, run):
    """Best-effort cancel of *run* so the thread is not left with an active run.

    The Assistants API rejects new messages on a thread while a run on it is
    still active.
    """
    try:
        client.beta.threads.runs.cancel(thread_id=thread.id, run_id=run.id)
    except Exception as e:
        report_error(f"Could not cancel run {run.id}: {e}")


def wait_for_response(thread, run, timeout=60):
    """Poll an Assistants run until it completes, executing requested tool calls.

    Args:
        thread: Thread object; only its ``id`` is used.
        run: Run object returned by ``runs.create``; only its ``id`` is used.
        timeout: Seconds to keep polling before giving up (default 60).

    Returns:
        The thread's message listing (result of ``messages.list``). It is returned
        on success and also after an unhandled status, an unknown required action,
        or a timeout, so callers always get the thread's current messages.
    """
    client = st.session_state.openai_client
    started = time.time()
    while time.time() - started < timeout:
        run_status = client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id
        )
        print(f"Run status: {run_status.status}")
        _log_run_steps(client, thread, run)

        if run_status.status == 'completed':
            messages = client.beta.threads.messages.list(thread_id=thread.id)
            _print_messages(messages)
            return messages

        if run_status.status in ('queued', 'in_progress'):
            print(f'{run_status.status.capitalize()}... Please wait.')
            time.sleep(1.5)  # Wait before checking again
            continue

        if run_status.status != 'requires_action':
            report_error(f"Unhandled Run status: {run_status.status}\n\nError: {run_status.last_error}\n")
            break

        required_action = run_status.required_action
        if required_action.type != 'submit_tool_outputs':
            report_error(f"Unknown required action type: {required_action}")
            break

        print(f"Requires tool outputs: {required_action}")
        tool_outputs = [
            {"tool_call_id": tool_call.id, "output": _execute_tool_call(tool_call)}
            for tool_call in required_action.submit_tool_outputs.tool_calls
        ]
        print(f"Finished tools calling: {str(tool_outputs)[:400]}")
        run = client.beta.threads.runs.submit_tool_outputs(
            thread_id=thread.id,
            run_id=run.id,
            tool_outputs=tool_outputs,
        )
        print(f"Required action {run_status.required_action}")
    else:
        # Loop condition became false without a break: the timeout elapsed.
        report_error(f"Wait for response timeout after {timeout}")
        _cancel_run(client, thread, run)

    report_error("Flow not completed")
    messages = client.beta.threads.messages.list(thread_id=thread.id)
    return messages
            

def call_assistant(query, engine="gpt-3.5-turbo"):
    """Post *query* to the session's assistant thread and wait for the reply.

    Args:
        query: User message text to send.
        engine: Currently unused; kept for signature compatibility.

    Returns:
        The thread's message listing from ``wait_for_response``, or
        ``st.session_state.messages`` unchanged if *query* equals
        ``st.session_state.last_user_query``. Returns ``None`` if any API call
        raises; the error and its traceback are printed.
    """
    # Prevent re-sending the last message over and over.
    print(f"Last query {st.session_state.last_user_query}, current query {query}")
    if st.session_state.last_user_query == query:
        report_error(f"That query '{query}' was just sent. We don't send the same query twice in a row. ")
        return st.session_state.messages
    try:
        thread = st.session_state.assistant_thread
        assistant_id = st.session_state.assistant_id
        client = st.session_state.openai_client
        client.beta.threads.messages.create(
            thread.id,
            role="user",
            content=query,
        )
        run = client.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=assistant_id,
        )
        messages = wait_for_response(thread, run)
        print(f"====================\nOpen AI response\n {str(messages)[:1000]}\n====================\n")
        return messages
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()
        return None
        

def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
    """Send *prompt*, after the session chat history, to the Chat Completions API.

    Args:
        prompt: User message appended to ``st.session_state.messages``.
        engine: Model name passed as ``model``.
        temp: Sampling temperature.
        top_p: Nucleus-sampling mass. It is now forwarded to the API; it was
            previously accepted but silently ignored. The default 1.0 matches the
            API default.
        max_tokens: Upper bound on generated tokens.
            NOTE(review): 4048 looks like a typo for 4096; kept for compatibility.

    Returns:
        The stripped reply text, or ``"Failed to generate a response."`` if the
        request raises or the reply carries no text content.

    Raises:
        Exception: if ``st.session_state.report_type`` is ``"assistant"``; use
            ``call_assistant`` in that mode.
    """
    if st.session_state.report_type == "assistant":
        raise Exception("use call_assistant instead of call_openai")
    try:
        response = st.session_state.openai_client.chat.completions.create(
            model=engine,
            messages=st.session_state.messages + [{"role": "user", "content": prompt}],
            temperature=temp,
            top_p=top_p,
            seed=SEED,
            max_tokens=max_tokens,
        )
        print(f"====================\nOpen AI response\n {response}\n====================\n")
        content = response.choices[0].message.content
        if content is not None:
            return content.strip()
        print("An error occurred: response contained no text content")
    except Exception as e:
        print(f"An error occurred: {str(e)}")
    return "Failed to generate a response."


def send_message(role, content):
    """Append a message to the session's Assistants thread.

    Args:
        role: Message author role passed to ``messages.create``.
        content: Message text.

    Returns:
        The created message object. It was previously discarded; returning it
        lets callers inspect the new message.
    """
    return st.session_state.openai_client.beta.threads.messages.create(
        thread_id=st.session_state.assistant_thread.id,
        role=role,
        content=content,
    )

def start_conversation():
    """Create a fresh Assistants thread and store it as the session's current thread."""
    threads_api = st.session_state.openai_client.beta.threads
    st.session_state.assistant_thread = threads_api.create()
    

def run_assistant(poll_interval=0.5, timeout=None):
    """Start a run of the session assistant on the session thread and poll it.

    Polling continues while the run status is ``queued`` or ``in_progress``.

    Args:
        poll_interval: Seconds to sleep between status checks (default 0.5).
        timeout: Optional maximum seconds to poll. ``None`` (the default) polls
            until the status leaves queued/in_progress, as before.

    Returns:
        The most recently retrieved Run object. Its status may be ``completed``,
        ``requires_action``, ``failed``, and so on, or still
        ``queued``/``in_progress`` if *timeout* elapsed. Callers must check it.
    """
    client = st.session_state.openai_client
    thread_id = st.session_state.assistant_thread.id
    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=st.session_state.assistant.id,
    )
    started = time.time()
    while run.status in ("queued", "in_progress"):
        if timeout is not None and time.time() - started >= timeout:
            break
        # Sleep before re-checking so no extra delay follows a terminal status.
        time.sleep(poll_interval)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread_id,
            run_id=run.id,
        )
    return run