Spaces:
Running
Running
import logging
from typing import Any, Coroutine

import httpx
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field, PrivateAttr, model_validator

from .settings import EmmRetrieversSettings
def as_lc_docs(dicts: list[dict]) -> list[Document]:
    """Convert raw API result dicts into LangChain ``Document`` objects.

    Each input dict must carry a ``"page_content"`` and a ``"metadata"`` key;
    a missing key raises ``KeyError``. Output order matches input order.
    """
    documents: list[Document] = []
    for raw in dicts:
        content = raw["page_content"]
        meta = raw["metadata"]
        documents.append(Document(page_content=content, metadata=meta))
    return documents
# The simple retriever is built with a fixed spec/filter/params/route config
# and can then be used many times with different queries.
# Note: these are cheap to construct.
class EmmRetrieverV1(BaseRetriever):
    """LangChain retriever that queries an HTTP RAG endpoint.

    The spec/filter/params/route configuration is fixed when the instance is
    built; the same instance can then serve many queries, either through the
    synchronous or the asynchronous LangChain retriever interface.
    """

    settings: EmmRetrieversSettings
    spec: dict
    filter: dict | None = None
    params: dict = Field(default_factory=dict)
    route: str = "/r/rag-minimal/query"
    # When True, each returned Document gets its 0-based result position
    # stored as metadata["ref_key"].
    add_ref_key: bool = True

    # HTTP clients; populated by ``create_clients`` right after validation.
    _client: httpx.Client = PrivateAttr()
    _aclient: httpx.AsyncClient = PrivateAttr()

    # ------- interface impl:
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        """POST ``query`` to ``route`` and return the hits as Documents.

        Raises:
            httpx.HTTPStatusError: if the server answers with a 4xx/5xx status.
        """
        r = self._client.post(**self.search_post_kwargs(query))
        return self._handle_response(r)

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Async counterpart of ``_get_relevant_documents``.

        The annotation is the awaited result type: ``async def`` already
        wraps it in a coroutine.

        Raises:
            httpx.HTTPStatusError: if the server answers with a 4xx/5xx status.
        """
        r = await self._aclient.post(**self.search_post_kwargs(query))
        return self._handle_response(r)

    # ---------
    def _handle_response(self, r: httpx.Response) -> list[Document]:
        """Validate a search response and convert its ``documents`` payload.

        Shared by the sync and async paths so both handle errors identically.
        """
        if r.status_code == 422:
            # Log the raw body: it carries the server's validation details.
            # Using r.text (not r.json()) avoids masking the HTTP error below
            # with a decode error when the body is not JSON.
            logging.getLogger(__name__).error(
                "Search request rejected (422):\n%s", r.text
            )
        r.raise_for_status()
        resp = r.json()
        return self._as_lc_docs(resp["documents"])

    @model_validator(mode="after")
    def create_clients(self):
        """Build the sync and async HTTP clients from ``settings``.

        Runs automatically after model validation (pydantic "after"
        validators must return ``self``). Both clients share the base URL,
        bearer-token header and timeout.
        """
        auth_headers = {
            "Authorization": f"Bearer {self.settings.API_KEY.get_secret_value()}"
        }
        kwargs = dict(
            base_url=self.settings.API_BASE,
            headers=auth_headers,
            timeout=self.settings.DEFAULT_TIMEOUT,
        )
        self._client = httpx.Client(**kwargs)
        self._aclient = httpx.AsyncClient(**kwargs)
        return self

    @model_validator(mode="after")
    def apply_default_params(self):
        """Fill in ``cluster_name`` and ``index`` defaults from ``settings``.

        Values the caller passed in ``params`` take precedence over the
        defaults, because they are merged last.
        """
        defaults = {
            "cluster_name": self.settings.DEFAULT_CLUSTER,
            "index": self.settings.DEFAULT_INDEX,
        }
        self.params = {**defaults, **(self.params or {})}
        return self

    def _as_lc_docs(self, dicts: list[dict]) -> list[Document]:
        """Convert result dicts to Documents, optionally tagging positions.

        When ``add_ref_key`` is set, each Document's metadata gets
        ``ref_key`` = its 0-based index in the result list.
        """
        docs = as_lc_docs(dicts)
        if self.add_ref_key:
            for i, d in enumerate(docs):
                d.metadata["ref_key"] = i
        return docs

    def search_post_kwargs(self, query: str):
        """Return the keyword arguments for the search ``post`` call.

        The query string goes into the JSON body together with the configured
        ``spec`` and ``filter``; ``params`` are sent as URL query parameters.
        """
        return dict(
            url=self.route,
            params=self.params,
            json={"query": query, "spec": self.spec, "filter": self.filter},
        )