# jattokatarratto's picture
# Upload folder using huggingface_hub
# 3724ac8 verified
from typing import Any, Coroutine
import httpx
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field, PrivateAttr, model_validator
from .settings import EmmRetrieversSettings
def as_lc_docs(dicts: list[dict]) -> list[Document]:
    """Convert raw document dicts into LangChain ``Document`` objects.

    Args:
        dicts: Mappings that each provide a ``"page_content"`` and a
            ``"metadata"`` key.

    Returns:
        One ``Document`` per input dict, in the same order.

    Raises:
        KeyError: If an entry lacks ``"page_content"`` or ``"metadata"``.
    """
    converted: list[Document] = []
    for raw in dicts:
        content = raw["page_content"]
        meta = raw["metadata"]
        converted.append(Document(page_content=content, metadata=meta))
    return converted
# The simple retriever is built with a fixed spec/filter/params/route config
# and can then be used many times with different queries.
# Note that these are cheap to construct.
class EmmRetrieverV1(BaseRetriever):
    """LangChain retriever that queries a retrieval service over HTTP.

    Each query is sent as a POST to ``route`` (resolved against
    ``settings.API_BASE``) with ``params`` as URL parameters and a JSON body
    containing the query text plus the fixed ``spec`` and ``filter``.  The
    ``"documents"`` list of the JSON response is converted to ``Document``
    objects.

    A synchronous and an asynchronous ``httpx`` client are created when the
    model is validated; this class does not close them itself.
    """

    # Source of API_BASE, API_KEY, DEFAULT_TIMEOUT, DEFAULT_CLUSTER and
    # DEFAULT_INDEX.
    settings: EmmRetrieversSettings
    # Sent unchanged as the "spec" entry of every request body.
    spec: dict
    # Sent unchanged as the "filter" entry of every request body.
    filter: dict | None = None
    # URL parameters; "cluster_name" and "index" are filled in from settings
    # unless given explicitly.
    params: dict = Field(default_factory=dict)
    # Path the query is POSTed to.
    route: str = "/r/rag-minimal/query"
    # When true, every returned document gets metadata["ref_key"] set to its
    # position in the result list.
    add_ref_key: bool = True

    _client: httpx.Client = PrivateAttr()
    _aclient: httpx.AsyncClient = PrivateAttr()

    # ------- interface impl:
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Run ``query`` against the service with the synchronous client.

        Raises:
            httpx.HTTPStatusError: If the service answers with an error status.
        """
        r = self._client.post(**self.search_post_kwargs(query))
        return self._docs_from_response(r)

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Run ``query`` against the service with the asynchronous client.

        Raises:
            httpx.HTTPStatusError: If the service answers with an error status.
        """
        r = await self._aclient.post(**self.search_post_kwargs(query))
        return self._docs_from_response(r)

    def _docs_from_response(self, r: httpx.Response) -> list[Document]:
        """Check an HTTP response and convert its documents.

        Shared by the sync and async code paths so both report errors and
        parse results identically.
        """
        if r.status_code == 422:
            # Print the validation details before raising.  The body is only
            # diagnostic here, so a non-JSON body must not mask the HTTP
            # error raised below.
            try:
                detail = r.json()
            except ValueError:
                detail = r.text
            print("ERROR:\n", detail)
        r.raise_for_status()
        resp = r.json()
        return self._as_lc_docs(resp["documents"])

    # ---------
    @model_validator(mode="after")
    def create_clients(self):
        """Create the sync and async HTTP clients from ``settings``."""
        _auth_headers = {
            "Authorization": f"Bearer {self.settings.API_KEY.get_secret_value()}"
        }
        kwargs = dict(
            base_url=self.settings.API_BASE,
            headers=_auth_headers,
            timeout=self.settings.DEFAULT_TIMEOUT,
        )
        self._client = httpx.Client(**kwargs)
        self._aclient = httpx.AsyncClient(**kwargs)
        return self

    @model_validator(mode="after")
    def apply_default_params(self):
        """Fill in default ``cluster_name``/``index`` URL parameters.

        Values already present in ``params`` take precedence over the
        defaults from ``settings``.
        """
        self.params = {
            **{
                "cluster_name": self.settings.DEFAULT_CLUSTER,
                "index": self.settings.DEFAULT_INDEX,
            },
            **(self.params or {}),
        }
        return self

    def _as_lc_docs(self, dicts: list[dict]) -> list[Document]:
        """Convert response dicts to documents, tagging them if configured."""
        docs = as_lc_docs(dicts)
        if self.add_ref_key:
            for i, d in enumerate(docs):
                d.metadata["ref_key"] = i
        return docs

    def search_post_kwargs(self, query: str) -> dict:
        """Build the keyword arguments for the client's ``post`` call."""
        return dict(
            url=self.route,
            params=self.params,
            json={"query": query, "spec": self.spec, "filter": self.filter},
        )