import logging
from typing import Any, Coroutine

import httpx
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field, PrivateAttr, model_validator

from .settings import EmmRetrieversSettings

logger = logging.getLogger(__name__)


def as_lc_docs(dicts: list[dict]) -> list[Document]:
    """Convert raw document dicts into LangChain ``Document`` objects.

    Args:
        dicts: Items that each carry ``"page_content"`` and ``"metadata"``
            keys; a missing key raises ``KeyError``.

    Returns:
        One ``Document`` per input dict, in the same order.
    """
    return [
        Document(page_content=d["page_content"], metadata=d["metadata"])
        for d in dicts
    ]


class EmmRetrieverV1(BaseRetriever):
    """Retriever that POSTs queries to an EMM retrieval HTTP endpoint.

    The spec/filter/params/route configuration is fixed at construction
    time, so a single instance can be reused for many different queries.
    Construction creates one sync and one async ``httpx`` client (see
    ``create_clients``); call ``close()`` / ``aclose()`` when the retriever
    is no longer needed to release their connection pools.
    """

    settings: EmmRetrieversSettings
    # Sent as the "spec" field of every request body.
    spec: dict
    # Sent as the "filter" field of every request body (may be None).
    filter: dict | None = None
    # Query-string parameters; cluster/index defaults are merged in by
    # ``apply_default_params`` (explicit values win).
    params: dict = Field(default_factory=dict)
    # Path relative to ``settings.API_BASE``.
    route: str = "/r/rag-minimal/query"
    # When True, each returned Document gets metadata["ref_key"] = its position.
    add_ref_key: bool = True

    _client: httpx.Client = PrivateAttr()
    _aclient: httpx.AsyncClient = PrivateAttr()

    # ------- interface impl:

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Synchronously run ``query`` against the configured route.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        """
        response = self._client.post(**self.search_post_kwargs(query))
        return self._parse_response(response)

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Asynchronously run ``query`` against the configured route.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        """
        response = await self._aclient.post(**self.search_post_kwargs(query))
        return self._parse_response(response)

    # ---------

    def _parse_response(self, response: httpx.Response) -> list[Document]:
        """Validate an HTTP response and convert its documents.

        Shared by the sync and async paths so both handle errors identically.
        """
        if response.status_code == 422:
            # HTTP 422 (Unprocessable Entity): log the raw body, which presumably
            # describes what the server rejected. Use ``.text`` rather than
            # ``.json()`` so a non-JSON error body cannot mask the status error
            # raised below.
            logger.error("EMM retriever request rejected (422): %s", response.text)
        response.raise_for_status()
        return self._as_lc_docs(response.json()["documents"])

    @model_validator(mode="after")
    def create_clients(self):
        """Create the sync and async HTTP clients from ``settings``.

        Both clients share the base URL, bearer-token auth header and timeout.
        """
        auth_headers = {
            "Authorization": f"Bearer {self.settings.API_KEY.get_secret_value()}"
        }
        client_kwargs = dict(
            base_url=self.settings.API_BASE,
            headers=auth_headers,
            timeout=self.settings.DEFAULT_TIMEOUT,
        )
        self._client = httpx.Client(**client_kwargs)
        self._aclient = httpx.AsyncClient(**client_kwargs)
        return self

    @model_validator(mode="after")
    def apply_default_params(self):
        """Fill in default ``cluster_name``/``index`` query params.

        Values passed explicitly in ``params`` take precedence over the
        defaults from ``settings``.
        """
        self.params = {
            "cluster_name": self.settings.DEFAULT_CLUSTER,
            "index": self.settings.DEFAULT_INDEX,
            **(self.params or {}),
        }
        return self

    def close(self) -> None:
        """Close the synchronous HTTP client and release its connections."""
        self._client.close()

    async def aclose(self) -> None:
        """Close the asynchronous HTTP client and release its connections."""
        await self._aclient.aclose()

    def _as_lc_docs(self, dicts: list[dict]) -> list[Document]:
        """Convert response dicts to Documents, optionally tagging ``ref_key``.

        When ``add_ref_key`` is True, each document's metadata gets
        ``"ref_key"`` set to its 0-based position in the response.
        """
        docs = as_lc_docs(dicts)
        if self.add_ref_key:
            for i, doc in enumerate(docs):
                doc.metadata["ref_key"] = i
        return docs

    def search_post_kwargs(self, query: str):
        """Build the keyword arguments for an ``httpx`` ``post`` call.

        Returns:
            A dict with the route as ``url``, the merged query-string
            ``params``, and a JSON body containing ``query``, ``spec`` and
            ``filter``.
        """
        return dict(
            url=self.route,
            params=self.params,
            json={"query": query, "spec": self.spec, "filter": self.filter},
        )