# %%
from pprint import pprint
import os
import httpx
# from pydantic_settings import BaseSettings, SettingsConfigDict
# from pydantic import SecretStr
#
# model_config = SettingsConfigDict(env_prefix="EMM_RETRIEVERS_", env_file="/eos/jeodpp/home/users/consose/PycharmProjects/disasterStories-prj/.env")
#
# class RetrieverSettings(BaseSettings):
#     api_base: str
#     api_key: SecretStr
#
#     class Config:
#         config_dict = model_config
#
# settings = RetrieverSettings()
# print(settings.api_base)
# print(settings.api_key.get_secret_value())
from client_v1.formatting_utils import fixed_width_wrap, format_docs
from client_v1.settings import EmmRetrieversSettings
# %%
settings = EmmRetrieversSettings()
settings.API_BASE
# the test index configuration
TEST_INDEX = "mine_e_emb-rag_live_test_001"
INDEX_MIN = "2024-09-14"
INDEX_MAX = "2024-09-28"
# instantiate an httpx client once with base url and auth
client = httpx.Client(
    base_url=settings.API_BASE,
    headers={"Authorization": f"Bearer {settings.API_KEY.get_secret_value()}"},
)
# %%
# get your auth info
client.get("/_cat/token").json()
EXAMPLE_QUESTION = "What natural disasters are currently occurring?"
# %%
r = client.post(
    "/r/rag-minimal/query",
    params={"cluster_name": settings.DEFAULT_CLUSTER, "index": TEST_INDEX},
    json={
        "query": EXAMPLE_QUESTION,
        "spec": {"search_k": 20},
        "filter": {
            "max_chunk_no": 1,
            "min_chars": 200,
            "start_dt": "2024-09-19",
            "end_dt": "2024-09-20",
        },
    },
)
r.raise_for_status()
search_resp = r.json()
documents = search_resp["documents"]
print(len(documents))
titles = [d["metadata"]["title"] for d in documents]
print("\n".join([f"- {title}" for title in titles]))
# %%
# full chunk formatting:
print(format_docs(documents, fixed_width=True))
# %%
# Using the gpt@jrc language models
from client_v1.jrc_openai import JRCChatOpenAI
llm_model = JRCChatOpenAI(
    model="llama-3.1-70b-instruct",
    openai_api_key=settings.OPENAI_API_KEY.get_secret_value(),
    openai_api_base=settings.OPENAI_API_BASE_URL,
)
resp = llm_model.invoke("What is the JRC?")
print(resp.content)
pprint(resp.response_metadata)
# %%
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer "
    "the question. If you don't know the answer, say that you "
    "don't know."
    "\n\n"
    "{context}"
)
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)
rag_chain = prompt | llm_model
# Add the API key to the LLM model
#llm_model.api_key = settings.OPENAI_API_KEY.get_secret_value()
r = rag_chain.invoke({"input": EXAMPLE_QUESTION, "context": format_docs(documents)})
print(fixed_width_wrap(r.content))
print("-" * 42)
pprint(r.response_metadata)
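# %%
# A hedged sketch (not part of the original flow): the RunnablePassthrough and
# StrOutputParser imports above suggest a chain that fetches its own context.
# `retrieve_docs` and `full_chain` are illustrative names; the request payload
# simply mirrors the earlier /r/rag-minimal/query call.
def retrieve_docs(question: str) -> str:
    resp = client.post(
        "/r/rag-minimal/query",
        params={"cluster_name": settings.DEFAULT_CLUSTER, "index": TEST_INDEX},
        json={
            "query": question,
            "spec": {"search_k": 20},
            "filter": {
                "max_chunk_no": 1,
                "min_chars": 200,
                "start_dt": "2024-09-19",
                "end_dt": "2024-09-20",
            },
        },
    )
    resp.raise_for_status()
    return format_docs(resp.json()["documents"])


full_chain = (
    {"context": retrieve_docs, "input": RunnablePassthrough()}
    | prompt
    | llm_model
    | StrOutputParser()
)
# print(fixed_width_wrap(full_chain.invoke(EXAMPLE_QUESTION)))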
# %% [markdown]
# notes:
# - custom retriever class (a hedged sketch follows in the next cell)
# - multiquery retrieval https://python.langchain.com/docs/how_to/MultiQueryRetriever/
# - self query https://python.langchain.com/docs/how_to/self_query/
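# %%
# A minimal sketch of the "custom retriever class" note above, assuming the
# standard langchain_core BaseRetriever interface; EmmApiRetriever and its
# fields are hypothetical names, and the "page_content" key is assumed from
# the LangChain-style document shape returned by the API.
from typing import Any, List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class EmmApiRetriever(BaseRetriever):
    http_client: Any  # the httpx.Client created above
    index: str
    cluster_name: str
    search_k: int = 20

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        resp = self.http_client.post(
            "/r/rag-minimal/query",
            params={"cluster_name": self.cluster_name, "index": self.index},
            json={"query": query, "spec": {"search_k": self.search_k}},
        )
        resp.raise_for_status()
        return [
            Document(page_content=d.get("page_content", ""), metadata=d["metadata"])
            for d in resp.json()["documents"]
        ]


# retriever = EmmApiRetriever(
#     http_client=client, index=TEST_INDEX, cluster_name=settings.DEFAULT_CLUSTER
# )
# retriever.invoke(EXAMPLE_QUESTION)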
# %%
# using prompt hubs
import langchain.hub
if hasattr(settings, "LANGCHAIN_API_KEY"):
    os.environ["LANGCHAIN_API_KEY"] = settings.LANGCHAIN_API_KEY.get_secret_value()
rag_prompt = langchain.hub.pull("rlm/rag-prompt")
print(
    fixed_width_wrap(
        rag_prompt.format(**{k: "{" + k + "}" for k in rag_prompt.input_variables})
    )
)
# %%
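# Hedged follow-up: the pulled prompt appears to expect "context" and
# "question" variables (see the print above), so it can be dropped into the
# same pattern used earlier; `hub_chain` is an illustrative name.
hub_chain = rag_prompt | llm_model | StrOutputParser()
# print(fixed_width_wrap(hub_chain.invoke(
#     {"question": EXAMPLE_QUESTION, "context": format_docs(documents)}
# )))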