"""Conversational QA Chain"""
from __future__ import annotations
import os
import re
import time
import logging
from fastapi import FastAPI
from pydantic import BaseModel
from langchain.chat_models import ChatOpenAI, ChatAnthropic
from langchain.memory import ConversationTokenBufferMemory
from convo_qa_chain import ConvoRetrievalChain

from toolkit.together_api_llm import TogetherLLM
from toolkit.retrivers import MyRetriever
from toolkit.local_llm import load_local_llm
from toolkit.utils import (
    Config,
    choose_embeddings,
    load_embedding,
    load_pickle,
    check_device,
)
app =FastAPI()

# Load the config file
configs = Config("configparser.ini")
logger = logging.getLogger(__name__)

os.environ["OPENAI_API_KEY"] = configs.openai_api_key
os.environ["ANTHROPIC_API_KEY"] = configs.anthropic_api_key

embedding = choose_embeddings(configs.embedding_name)
db_store_path = configs.db_dir


# get models
def get_llm(llm_name: str, temperature: float, max_tokens: int):
    """Get the LLM model from the model name."""

    if not os.path.exists(configs.local_model_dir):
        os.makedirs(configs.local_model_dir)

    splits = llm_name.split("|")  # [provider, model_name, model_file]

    if "openai" in splits[0].lower():
        llm_model = ChatOpenAI(
            model=splits[1],
            temperature=temperature,
            max_tokens=max_tokens,
        )

    elif "anthropic" in splits[0].lower():
        llm_model = ChatAnthropic(
            model=splits[1],
            temperature=temperature,
            max_tokens_to_sample=max_tokens,
        )

    elif "together" in splits[0].lower():
        llm_model = TogetherLLM(
            model=splits[1],
            temperature=temperature,
            max_tokens=max_tokens,
        )
    elif "huggingface" in splits[0].lower():
        llm_model = load_local_llm(
            model_id=splits[1],
            model_basename=splits[-1],
            temperature=temperature,
            max_tokens=max_tokens,
            device_type=check_device(),
        )
    else:
        raise ValueError("Invalid Model Name")

    return llm_model


llm = get_llm(configs.model_name, configs.temperature, configs.max_llm_generation)


# load retrieval database
db_embedding_chunks_small = load_embedding(
    store_name=configs.embedding_name,
    embedding=embedding,
    suffix="chunks_small",
    path=db_store_path,
)
db_embedding_chunks_medium = load_embedding(
    store_name=configs.embedding_name,
    embedding=embedding,
    suffix="chunks_medium",
    path=db_store_path,
)

db_docs_chunks_small = load_pickle(
    prefix="docs_pickle", suffix="chunks_small", path=db_store_path
)
db_docs_chunks_medium = load_pickle(
    prefix="docs_pickle", suffix="chunks_medium", path=db_store_path
)
file_names = load_pickle(prefix="file", suffix="names", path=db_store_path)


# Initialize the retriever
my_retriever = MyRetriever(
    llm=llm,
    embedding_chunks_small=db_embedding_chunks_small,
    embedding_chunks_medium=db_embedding_chunks_medium,
    docs_chunks_small=db_docs_chunks_small,
    docs_chunks_medium=db_docs_chunks_medium,
    first_retrieval_k=configs.first_retrieval_k,
    second_retrieval_k=configs.second_retrieval_k,
    num_windows=configs.num_windows,
    retriever_weights=configs.retriever_weights,
)


# Initialize the memory
memory = ConversationTokenBufferMemory(
    llm=llm,
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
    return_messages=True,
    max_token_limit=configs.max_chat_history,
)


# Initialize the QA chain
qa = ConvoRetrievalChain.from_llm(
    llm,
    my_retriever,
    file_names=file_names,
    memory=memory,
    return_source_documents=False,
    return_generated_question=False,
)




class Question(BaseModel):
    question: str

@app.get("/chat/")
def chat_with(str1: str):
    resp = qa({"question": str1})
    answer = resp.get('answer', '')
    return {'message': answer}

  
# @app.get("/")  

# def chat_with(str1):
#     resp = qa({"question": str1})
#     return {'message':resp}

    

    
'''
if __name__ == "__main__":
    while True:
        user_input = input("Human: ")
        start_time = time.time()
        user_input_ = re.sub(r"^Human: ", "", user_input)
        print("*" * 6)
        resp = qa({"question": user_input_})
        print()
        print(f"AI:{resp['answer']}")
        print(f"Time used: {time.time() - start_time}")
        print("-" * 60)
'''