# from fastapi import FastAPI
# from fastapi.middleware.cors import CORSMiddleware
from openai import OpenAI
from google import genai
from crawler import extract_data
import time
import os
from dotenv import load_dotenv
import gradio as gr
# import multiprocessing
from together import Together

# Load API keys from ../.env. The relative path is resolved against the current
# working directory, not this file's location. With python-dotenv's default
# (override=False), variables already set in the environment take precedence.
load_dotenv("../.env")
# NOTE(review): uncommenting this prints every secret loaded above.
# print("Environment variables:", os.environ)


# Together client; backs the Qwen and Llama helpers below.
# os.getenv returns None when a key is missing. How each SDK reacts to a None
# api_key is up to that SDK.
together_client = Together(
    api_key=os.getenv("TOGETHER_API_KEY"),
)

# Gemini client plus the model name shared by the Naver query-title step and the
# Gemini answer helper.
gemini_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
genai_model = "gemini-2.0-flash-exp"

# Perplexity exposes an OpenAI-compatible API, hence the OpenAI client with a custom
# base_url. In the visible code it is only used by the commented-out
# get_perplexity_answer.
perplexity_client = OpenAI(api_key=os.getenv("PERPLEXITY_API_KEY"), base_url="https://api.perplexity.ai")
# OpenAI client used for gpt-4o-mini calls.
gpt_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))



def get_answers(query: str):
    """Return the crawled context for *query* from ``crawler.extract_data``.

    The literal ``1`` is forwarded as ``extract_data``'s second argument unchanged.
    NOTE(review): its meaning (page count? result limit?) is defined in ``crawler``;
    confirm there.
    """
    crawled = extract_data(query, 1)
    return crawled

# with torch.no_grad():
#     model = AutoModel.from_pretrained('BM-K/KoSimCSE-roberta')
#     tokenizer = AutoTokenizer.from_pretrained('BM-K/KoSimCSE-roberta', TOKENIZERS_PARALLELISM=True)

# def cal_score(input_data):
#     # Initialize model and tokenizer inside the function
#     with torch.no_grad():
#         inputs = tokenizer(input_data, padding=True, truncation=True, return_tensors="pt")
#         outputs = model.get_input_embeddings(inputs["input_ids"])
        
#         a, b = outputs[0], outputs[1]  # Adjust based on your model's output structure

#         # Normalize the tensors
#         a_norm = a / a.norm(dim=1)[:, None]
#         b_norm = b / b.norm(dim=1)[:, None]

#         print(a.shape, b.shape)
        
#         # Return the similarity score
#         # return torch.mm(a_norm, b_norm.transpose(0, 1)) * 100
#         a_norm = a_norm.reshape(1, -1)
#         b_norm = b_norm.reshape(1, -1)
#         similarity_score = cosine_similarity(a_norm, b_norm)

#         # Return the similarity score (assuming you want the average of the similarities across the tokens)
#         return similarity_score # Scalar value



# def get_match_scores( message: str, query: str, answers: list[dict[str, object]] ):
#     start = time.time()
#     max_processes = 4
#     with multiprocessing.Pool(processes=max_processes) as pool:
#         scores = pool.map(cal_score, [[answer['questionDetails'], message] for answer in answers])
#     print(f"Time taken to compare: {time.time() - start} seconds")
#     print("Scores: ", scores)
#     return scores

def get_naver_answers( message: str ):
    """Turn *message* into a short Korean title, crawl Naver with it, and merge the answers.

    Gemini (``genai_model``) is asked to condense the question into a title. That
    title is passed to ``get_answers`` as the search query. Each crawled entry's
    ``'answers'`` list is joined with ``'. '``, and the entries are separated by
    newlines.

    Args:
        message: The user's question.

    Returns:
        ``(document, elapsed_seconds)``. ``elapsed_seconds`` covers both the Gemini
        call and the crawl.
    """
    print(">>> Starting naver extraction...")
    print("Question: ", message)
    naver_start_time = time.time()
    response = gemini_client.models.generate_content(
        model = genai_model,
        contents=f"{message}\n 위의 내용을 짧은 제목으로 요약합니다. 제목만 보여주세요. 대답하지 마세요. 한국어로만 답변해주세요!!!",
    )
    # `response.text` is None when Gemini returns no text part (e.g. a blocked
    # response). Fall back to the raw question rather than crawling for None.
    # Strip any surrounding whitespace/newlines the model may emit so they do not
    # leak into the search query.
    query = (response.text or message).strip()
    print( "Query: ", query)

    # Defensive: treat a falsy crawler result as "no answers".
    # NOTE(review): confirm whether extract_data can actually return None.
    context = get_answers( query ) or []

    sorted_answers = ['. '.join(answer['answers']) for answer in context]
    naver_end_time = time.time()
    elapsed = naver_end_time - naver_start_time
    print(f"Time taken to extract from Naver: { elapsed } seconds")
    document = '\n'.join(sorted_answers)
    return document, elapsed

# System prompt shared by both Llama variants.
_LLAMA_SYSTEM_PROMPT = "You are an artificial intelligence assistant and you need to engage in a helpful, CONCISE, polite question-answer conversation with a user."

def _timed_together_chat( start_label: str, log_name: str, model: str, system_prompt: str, message: str, **create_kwargs ):
    """Send one system+user chat turn to a Together model and time the request.

    Args:
        start_label: Model label used in the ">>> Starting ... extraction..." log line.
        log_name: Name used in the "Time taken to extract from ..." log line.
        model: Together model identifier.
        system_prompt: Content of the system message.
        message: Content of the user message.
        **create_kwargs: Extra keyword arguments forwarded verbatim to
            ``together_client.chat.completions.create`` (e.g. ``max_tokens``).

    Returns:
        ``(content, elapsed_seconds)``: the first choice's message content and the
        wall-clock duration of the request.
    """
    print(f">>> Starting {start_label} extraction...")
    start_time = time.time()
    response = together_client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        **create_kwargs,
    )
    elapsed = time.time() - start_time
    print(f"Time taken to extract from {log_name}: { elapsed } seconds")
    return response.choices[0].message.content, elapsed

def get_qwen_big_answer( message: str ):
    """Answer *message* with Qwen2.5-72B as a concise Korean assistant; returns ``(answer, seconds)``."""
    return _timed_together_chat(
        "Qwen 72B",
        "Qwen",
        "Qwen/Qwen2.5-72B-Instruct-Turbo",
        "You are a helpful question-answer, CONCISE conversation assistant that answers in Korean.",
        message,
    )

def get_qwen_small_answer( message: str ):
    """Answer *message* with Qwen2.5-7B as a human-sounding Korean assistant; returns ``(answer, seconds)``."""
    return _timed_together_chat(
        "Qwen 7B",
        "Qwen",
        "Qwen/Qwen2.5-7B-Instruct-Turbo",
        "You are a helpful question-answer, conversation assistant that answers in Korean. Your responses should sound human-like.",
        message,
        # Sent explicitly, as before; left to the API/SDK to interpret.
        # TODO: Change the messages option
        max_tokens=None,
    )

def get_llama_small_answer( message: str ):
    """Answer *message* with Llama 3.1 8B; returns ``(answer, seconds)``."""
    return _timed_together_chat(
        "Llama 3.1 8B",
        "Llama",
        "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        _LLAMA_SYSTEM_PROMPT,
        message,
    )

def get_llama_big_answer( message: str ):
    """Answer *message* with Llama 3.1 70B; returns ``(answer, seconds)``."""
    return _timed_together_chat(
        "Llama 3.1 70B",
        "Llama",
        "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        _LLAMA_SYSTEM_PROMPT,
        message,
    )


def get_gemini_answer( message: str ):
    """Answer *message* with Gemini (``genai_model``) and time the request.

    Returns:
        ``(answer_text, elapsed_seconds)``. ``answer_text`` is ``response.text``,
        or ``""`` when Gemini returned no text part.
    """
    print(">>> Starting gemini extraction...")
    gemini_start_time = time.time()
    response = gemini_client.models.generate_content(
        model = genai_model,
        contents=message,
    )
    gemini_end_time = time.time()
    print(f"Time taken to extract from Gemini: { gemini_end_time - gemini_start_time } seconds")
    # Return plain text rather than the `Content` object, like the other
    # get_*_answer helpers. Callers embed the answer in prompts and result rows,
    # where a Content object would appear as its repr.
    return response.text or "", gemini_end_time - gemini_start_time

# def get_perplexity_answer( message: str ):
#     print(">>> Starting perplexity extraction...")
#     perplexity_start_time = time.time()
#     messages = [
#         {
#             "role": "system",
#             "content": (
#                 "You are an artificial intelligence assistant and you need to "
#                 "engage in a helpful, CONCISE, polite question-answer conversation with a user."
#             ),
#         },
#         {   
#             "role": "user",
#             "content": (
#                 message
#             ),
#         },
#     ]
#     response = perplexity_client.chat.completions.create(
#         model="llama-3.1-sonar-small-128k-online",
#         messages=messages
#     )
#     perplexity_end_time = time.time()
#     print(f"Time taken to extract from Perplexity: { perplexity_end_time - perplexity_start_time } seconds")
#     return response.choices[0].message.content, perplexity_end_time - perplexity_start_time

def get_gpt_answer(message: str):
    """Ask gpt-4o-mini for a short, no-frills answer to *message*.

    Returns:
        ``(answer, elapsed_seconds)`` for the single chat-completion request.
    """
    print(">>> Starting GPT extraction...")
    began = time.time()
    conversation = [
        {"role": "system", "content": "You are a helpful assistant that gives short answers and nothing extra."},
        {"role": "user", "content": message},
    ]
    completion = gpt_client.chat.completions.create(model="gpt-4o-mini", messages=conversation)
    elapsed = time.time() - began
    print(f"Time taken to extract from GPT: {elapsed} seconds")
    reply = completion.choices[0].message.content
    return reply, elapsed

def compare_answers(message: str):
    """Benchmark every model as a stand-alone answerer and as a summariser over retrieved docs.

    For *message* this produces result rows (dicts with ``Method``, ``Question``,
    ``Answer`` and ``Time Taken`` keys) for:

    1. ``Naver + (summarizer)``: each model answers from the Naver documents.
    2. ``method``: each model answers the raw question on its own.
    3. ``Naver + method + (summarizer)``: each model answers from the Naver
       documents plus that single method's stand-alone answer.

    ``Time Taken`` sums the stages that fed each answer (Naver retrieval, the
    extractor call and the summariser call). With N methods this makes
    ``N * N + 2 * N`` model calls on top of the Naver retrieval.

    Returns:
        The list of result-row dicts.
    """
    methods = [
        ("Qwen Big (72B)", get_qwen_big_answer),
        ("Qwen Small (7B)", get_qwen_small_answer),
        ("Llama Small (8B)", get_llama_small_answer),
        ("Llama Big (70B)", get_llama_big_answer),
        ("Gemini-2.0-Flash", get_gemini_answer),
        # ("Perplexity", get_perplexity_answer),
        ("GPT (4o-mini)", get_gpt_answer)
    ]

    results = []

    naver_docs, naver_time_taken = get_naver_answers( message )
    base_content = f'아래 문서를 바탕으로 질문에 답하세요. 답변은 한국어로만 해주세요 \n 질문 {message}\n'
    base_content += naver_docs
    print("Starting the comparison between summarizers...")
    for method_name, method in methods:
        answer, time_taken = method(base_content)
        results.append({
            "Method": f"Naver + ({method_name})",
            "Question": message,
            "Answer": answer,
            "Time Taken": naver_time_taken + time_taken
        })

    print("Starting the comparison between extractors/summarizers...")
    for method_name, method in methods:
        additional_docs, extract_time = method(message)
        results.append({
            "Method": method_name,
            "Question": message,
            "Answer": additional_docs,
            "Time Taken": extract_time
        })
        # Build a fresh prompt per extractor, so each row's prompt holds only the
        # Naver docs plus THIS method's output, matching its label. Appending to a
        # shared prompt would leak every earlier extractor's output into later rows.
        content = f'{base_content}\n{additional_docs}'
        upstream_time = naver_time_taken + extract_time
        for summarizer_name, summarizer in methods:
            answer, answer_time = summarizer(content)
            results.append({
                "Method": f"Naver + {method_name} + ({summarizer_name})",
                "Question": message,
                "Answer": answer,
                "Time Taken": upstream_time + answer_time
            })
    return results

def _chunk_with_overlap( text: str, split_threshold: int = 55000, overlap: int = 200 ):
    """Split *text* into roughly equal chunks that share *overlap* characters at each boundary.

    The chunk count is ``(len(text) - split_threshold) // split_threshold + 2``. For
    ``len(text) > split_threshold`` this keeps each chunk's base length at or below
    *split_threshold*; each chunk after the first also starts *overlap* characters
    early. Together the chunks cover all of *text*.

    Returns:
        The list of chunk strings, in order.
    """
    num_chunks = (len(text) - split_threshold) // split_threshold + 2
    split_len = len(text) // num_chunks + 1
    chunks = []
    for i in range(num_chunks):
        start = 0 if i == 0 else i * split_len - overlap
        chunks.append(text[start:(i + 1) * split_len])
    return chunks

def chatFunction( message, history ):
    """Gradio chat handler: answer *message* from Naver results plus a GPT answer.

    Pipeline:
    1. Crawl Naver via ``get_naver_answers``.
    2. If the merged document exceeds 55000 characters, summarise it chunk by
       chunk with Qwen 7B.
    3. Ask gpt-4o-mini the bare question to obtain an extra "other document".
    4. Let Qwen 7B write the final answer from both documents.

    Args:
        message: The user's latest message.
        history: Conversation history supplied by Gradio; unused.

    Returns:
        The answer string shown in the chat UI.
    """
    content = f'아래 문서를 바탕으로 질문에 답하세요. 답변에서 질문을 따라 출력 하지 마세요. 답변은 한국어로만 해주세요! 찾은 Naver 문서와 다른 문서에서 답변이 없는 내용은 절대 출력하지 마세요. 친절하고 인간답게 말하세요. \n 질문: {message}\n 문서: '
    naver_docs, _ = get_naver_answers( message )

    if len(naver_docs) > 55000:
        # Condense every chunk. Iterating over a range that starts at
        # len(naver_docs) and stops at the (smaller) chunk length never runs, so
        # long crawl results would be replaced by an empty string.
        answers = []
        for split in _chunk_with_overlap(naver_docs, 55000, 200):
            answer, _ = get_qwen_small_answer(f"Summarize important points in a paragraph, given the information below, using only Korean language. Give me only the summary!!! \n {split}")
            answers.append(answer)
        naver_docs = '\n'.join(answers)

    start_time = time.time()
    content += "\n Naver 문서: " + naver_docs

    # Independent GPT answer to the bare question, used as a second source.
    completion = gpt_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are a helpful assistant that gives detailed answers only in korean."},
            {
                "role": "user",
                "content": message
            }
        ]
    )
    gpt_resp = completion.choices[0].message.content
    content += "\n 다른 문서: " + gpt_resp

    answer, _ = get_qwen_small_answer(content)

    print("-"*70)
    print("Question: ", message)
    print("Answer: ", answer)
    time_taken = time.time() - start_time
    print("Time taken to summarize: ", time_taken)
    return answer
    

if __name__ == "__main__":
    # ChatInterface assembles its own Blocks layout in its constructor, so an
    # extra `with ... : pass` context is unnecessary.
    demo = gr.ChatInterface( fn=chatFunction, type="messages" )
    # NOTE(review): share=True publishes a public gradio.live link. Anyone with it
    # can trigger paid API calls on the configured keys; confirm this is intended.
    demo.launch(share=True)
    # Offline benchmark harness (requires `import pandas as pd`):
    # with open("test_questions.txt", "r") as f:
    #     if os.path.exists("comparison_results.csv"):
    #         if input("Do you want to delete the former results? (y/n): ") == "y":
    #             os.remove("comparison_results.csv")
    #     questions = f.readlines()
    #     print(questions)
    #     for idx, question in enumerate(questions):
    #         print(" -> Starting the question number: ", idx)
    #         results = compare_answers(question)
    #         df = pd.DataFrame(results)
    #         df.to_csv("comparison_results.csv", mode='a', index=False)