# %%
import os
from pprint import pprint

import httpx

from client_v1.formatting_utils import fixed_width_wrap, format_docs
from client_v1.settings import EmmRetrieversSettings

# %%
settings = EmmRetrieversSettings()

settings.API_BASE

# the test index configuration
TEST_INDEX = "mine_e_emb-rag_live_test_001"
INDEX_MIN = "2024-09-14"
INDEX_MAX = "2024-09-28"

# instantiate an httpx client once with base url and auth
client = httpx.Client(
    base_url=settings.API_BASE,
    headers={"Authorization": f"Bearer {settings.API_KEY.get_secret_value()}"},
)
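
# %%
# Note: httpx defaults to a 5-second timeout, which slow RAG queries can
# exceed. A sketch of the same client with a longer, arbitrarily chosen
# timeout (kept commented out so the client above stays authoritative):
# client = httpx.Client(
#     base_url=settings.API_BASE,
#     headers={"Authorization": f"Bearer {settings.API_KEY.get_secret_value()}"},
#     timeout=30.0,
# )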


# %%
# check your auth info; fail fast if the API key is rejected
r = client.get("/_cat/token")
r.raise_for_status()
pprint(r.json())

EXAMPLE_QUESTION = "What natural disasters are currently occurring?"

# %%
r = client.post(
    "/r/rag-minimal/query",
    params={"cluster_name": settings.DEFAULT_CLUSTER, "index": TEST_INDEX},
    json={
        "query": EXAMPLE_QUESTION,
        "spec": {"search_k": 20},
        "filter": {
            "max_chunk_no": 1,
            "min_chars": 200,
            "start_dt": "2024-09-19",
            "end_dt": "2024-09-20",
        },
    },
)

r.raise_for_status()

search_resp = r.json()

documents = search_resp["documents"]
print(len(documents))
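
# peek at one document's metadata to see which fields are available
# (the "title" field used below comes from this dict)
pprint(documents[0]["metadata"])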


titles = [d["metadata"]["title"] for d in documents]

print("\n".join([f"- {title}" for title in titles]))

# %%
# full chunk formatting:

print(format_docs(documents, fixed_width=True))

# %%
# Using the gpt@jrc language models


from client_v1.jrc_openai import JRCChatOpenAI

llm_model = JRCChatOpenAI(
    model="llama-3.1-70b-instruct",
    openai_api_key=settings.OPENAI_API_KEY.get_secret_value(),
    openai_api_base=settings.OPENAI_API_BASE_URL,
)

resp = llm_model.invoke("What is the JRC?")
print(resp.content)
pprint(resp.response_metadata)

# %%

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser


system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer "
    "the question. If you don't know the answer, say that you "
    "don't know."
    "\n\n"
    "{context}"
)

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)

rag_chain = prompt | llm_model

r = rag_chain.invoke({"input": EXAMPLE_QUESTION, "context": format_docs(documents)})

print(fixed_width_wrap(r.content))
print("-" * 42)
pprint(r.response_metadata)
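
# %%
# StrOutputParser is imported above but unused so far; appending it to the
# chain returns the answer as a plain string instead of a message object.
# A minimal sketch:
str_chain = prompt | llm_model | StrOutputParser()
answer = str_chain.invoke(
    {"input": EXAMPLE_QUESTION, "context": format_docs(documents)}
)
print(fixed_width_wrap(answer))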

# %% [markdown]

# notes:
# - custom retriever class (a minimal sketch follows in the next cell)
# - multiquery retrieval: https://python.langchain.com/docs/how_to/MultiQueryRetriever/
# - self-query: https://python.langchain.com/docs/how_to/self_query/
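
# %%
# A minimal sketch of the "custom retriever class" note above: wrap the
# /r/rag-minimal/query endpoint in a LangChain BaseRetriever so retrieval can
# be composed into chains, putting the so-far unused RunnablePassthrough
# import to work. The request mirrors the query made earlier; the class name
# and search_k default are illustrative assumptions, not part of client_v1.
from typing import List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class EmmQueryRetriever(BaseRetriever):
    """Hypothetical retriever delegating to the module-level httpx client."""

    search_k: int = 20

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        r = client.post(
            "/r/rag-minimal/query",
            params={"cluster_name": settings.DEFAULT_CLUSTER, "index": TEST_INDEX},
            json={"query": query, "spec": {"search_k": self.search_k}},
        )
        r.raise_for_status()
        # assumes each document dict carries "page_content" and "metadata",
        # the shape implied by format_docs() and the title extraction above
        return [
            Document(page_content=d["page_content"], metadata=d["metadata"])
            for d in r.json()["documents"]
        ]


retriever = EmmQueryRetriever()
full_rag_chain = (
    {"context": retriever | format_docs, "input": RunnablePassthrough()}
    | prompt
    | llm_model
    | StrOutputParser()
)
print(fixed_width_wrap(full_rag_chain.invoke(EXAMPLE_QUESTION)))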


# %%
# Using prompt hubs

import langchain.hub

# getattr with a default avoids both a missing attribute and a declared-but-
# unset optional field, either of which would break get_secret_value()
if getattr(settings, "LANGCHAIN_API_KEY", None):
    os.environ["LANGCHAIN_API_KEY"] = settings.LANGCHAIN_API_KEY.get_secret_value()

    rag_prompt = langchain.hub.pull("rlm/rag-prompt")
    print(
        fixed_width_wrap(
            rag_prompt.format(**{k: "{" + k + "}" for k in rag_prompt.input_variables})
        )
    )
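
    # The pulled prompt can stand in for the hand-written one above; note that
    # rlm/rag-prompt expects "context" and "question" (not "input"). A sketch:
    hub_chain = rag_prompt | llm_model | StrOutputParser()
    hub_answer = hub_chain.invoke(
        {"context": format_docs(documents), "question": EXAMPLE_QUESTION}
    )
    print(fixed_width_wrap(hub_answer))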


# %%