# %%
from pprint import pprint
import os
import httpx

# from pydantic_settings import BaseSettings, SettingsConfigDict
# from pydantic import SecretStr
#
# model_config = SettingsConfigDict(env_prefix="EMM_RETRIEVERS_", env_file="/eos/jeodpp/home/users/consose/PycharmProjects/disasterStories-prj/.env")
#
# class RetrieverSettings(BaseSettings):
#     api_base: str
#     api_key: SecretStr
#
#     class Config:
#         config_dict = model_config
#
# settings = RetrieverSettings()
# print(settings.api_base)
#print(settings.api_key.get_secret_value())


from client_v1.formatting_utils import fixed_width_wrap, format_docs
from client_v1.settings import EmmRetrieversSettings



# %%
# Instantiate the project's settings object (API base URL, keys, cluster).
settings = EmmRetrieversSettings()

# Show which API endpoint is in use. A bare `settings.API_BASE` expression is
# only echoed by a notebook when it is the last statement of the cell (and
# never when run as a script), so print it explicitly.
print(settings.API_BASE)

# the test index configuration
TEST_INDEX = "mine_e_emb-rag_live_test_001"
# NOTE(review): presumably the date range covered by the test index -- confirm.
INDEX_MIN = "2024-09-14"
INDEX_MAX = "2024-09-28"

# instantiate an httpx client once with base url and auth
# (every request made through `client` carries the bearer token header)
client = httpx.Client(
    base_url=settings.API_BASE,
    headers={"Authorization": f"Bearer {settings.API_KEY.get_secret_value()}"},
)


# %%
# get your auth info
# Check the HTTP status first so an authentication problem surfaces as a clear
# httpx.HTTPStatusError, then print the payload: a bare expression that is not
# the last statement of the cell would be silently discarded.
token_resp = client.get("/_cat/token")
token_resp.raise_for_status()
pprint(token_resp.json())

# Question used for both the retrieval query and the RAG prompt below.
# NOTE(review): "occuring" is misspelled; left untouched because the string is
# sent verbatim to the retriever and the LLM -- confirm before changing.
EXAMPLE_QUESTION = "What natural disasters are currently occuring?"

# %%
# POST the example question to the "rag-minimal" query endpoint of the test
# index, with the search spec and filter values given below.
r = client.post(
    "/r/rag-minimal/query",
    params={
        "cluster_name": settings.DEFAULT_CLUSTER,
        "index": TEST_INDEX,
    },
    json={
        "query": EXAMPLE_QUESTION,
        "spec": {"search_k": 20},
        "filter": {
            "max_chunk_no": 1,
            "min_chars": 200,
            "start_dt": "2024-09-19",
            "end_dt": "2024-09-20",
        },
    },
)
# Stop here on any 4xx/5xx answer instead of parsing an error body.
r.raise_for_status()

search_resp = r.json()
documents = search_resp["documents"]
print(len(documents))

# One bullet line per retrieved document title.
titles = [doc["metadata"]["title"] for doc in documents]
print("\n".join(f"- {title}" for title in titles))

# %%
# full chunk formatting:

print(format_docs(documents, fixed_width=True))

# %%
# Using the gpt@jrc language models


from client_v1.jrc_openai import JRCChatOpenAI

# Chat model wrapper configured from the project settings (key + base URL).
llm_model = JRCChatOpenAI(
    model="llama-3.1-70b-instruct",
    openai_api_key=settings.OPENAI_API_KEY.get_secret_value(),
    openai_api_base=settings.OPENAI_API_BASE_URL,
)

# Quick smoke test: one plain question, then the answer text and the
# metadata attached to the response.
resp = llm_model.invoke("What is the JRC?")
print(resp.content)
pprint(resp.response_metadata)

# %%

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser


# System message template; "{context}" is filled with the formatted documents
# when the chain is invoked.
system_prompt = (
    "You are an assistant for question-answering tasks. Use the following "
    "pieces of retrieved context to answer the question. If you don't know "
    "the answer, say that you don't know.\n\n{context}"
)

# Two-message chat prompt: the system instructions followed by the user input.
prompt = ChatPromptTemplate.from_messages(
    [("system", system_prompt), ("human", "{input}")]
)

# Pipe the rendered prompt straight into the chat model.
rag_chain = prompt | llm_model

# Add the API key to the LLM model
#llm_model.api_key = settings.OPENAI_API_KEY.get_secret_value()

# Answer the example question using the documents retrieved earlier.
# Note that this rebinds `r`, which previously held the HTTP response.
r = rag_chain.invoke(
    {
        "input": EXAMPLE_QUESTION,
        "context": format_docs(documents),
    }
)

print(fixed_width_wrap(r.content))
print("-" * 42)
pprint(r.response_metadata)

# %% [markdown]

# notes:
# - custom retriever class
# - multiquery retrieval https://python.langchain.com/docs/how_to/MultiQueryRetriever/
# - self query https://python.langchain.com/docs/how_to/self_query/


# %%
# using prompt hubs

import langchain.hub

# The hub key is treated as optional. `hasattr` alone is not enough: if the
# attribute exists but is unset (None), calling `.get_secret_value()` on it
# would raise AttributeError. Skip the cell in both the "missing" and the
# "None" case.
langchain_api_key = getattr(settings, "LANGCHAIN_API_KEY", None)

if langchain_api_key is not None:
    # Export the key through the environment before pulling from the hub.
    os.environ["LANGCHAIN_API_KEY"] = langchain_api_key.get_secret_value()

    rag_prompt = langchain.hub.pull("rlm/rag-prompt")

    # Fill every input variable with its own "{name}" placeholder so the raw
    # template text is printed with the placeholders still visible.
    placeholders = {name: "{" + name + "}" for name in rag_prompt.input_variables}
    print(fixed_width_wrap(rag_prompt.format(**placeholders)))


# %%