# Spaces: Sleeping  (web-page status banner captured with this file; not part of the module)
"""Implements the document grader chain""" | |
# Standard library
import os
from textwrap import dedent

# Third-party
from dotenv import load_dotenv
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate

# Runs at import time, after all imports and before any class in this module
# is instantiated. NOTE(review): presumably this is what makes OPENAI_MODEL
# (read in DocumentGrader.__init__) available from a local .env file - confirm.
load_dotenv()
# Data model: schema of the tool call that the grader LLM is forced to make
class grade(BaseModel):
    """Binary score for relevance check."""

    # The lowercase class name is kept on purpose: DocumentGrader forces the
    # model to call a function named "grade" via ``tool_choice``.
    # NOTE(review): the class docstring and the field description below are
    # presumably copied into the tool schema built by convert_to_openai_tool,
    # i.e. they are part of what the model sees - confirm before rewording.

    # The verdict as text; the prompt asks the model for 'yes' or 'no'.
    # Declared as a plain ``str``, so no other value is rejected here.
    binary_score: str = Field(description="Relevance score 'yes' or 'no'")
class DocumentGrader:
    """Grade whether a retrieved document is relevant to a user question.

    The chain is assembled once in ``__init__`` as
    ``prompt | chat model forced to call the "grade" tool | tool-call parser``
    and executed by :meth:`run`.
    """

    # The template body is kept flush-left on purpose. ``textwrap.dedent`` only
    # strips whitespace common to *every* line, and the "\n\n {context}" escape
    # sequence below yields a line with a single leading space, so indenting
    # the block would leave stray indentation in the prompt sent to the model.
    # The "\n" sequences are real newline escapes (this is not a raw string).
    _GRADER_PROMPT_TEMPLATE = """
You are a grader assessing relevance of a retrieved document to a user question. \n
Retrieved document: \n\n {context} \n\n
User Question: {question} \n
When assessing the relevance of a retrieved document to a user question, consider whether the document can provide a complete answer to the question posed. A document is considered relevant only if it contains all the necessary information to fully answer the user's inquiry without requiring additional context or assumptions.
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
Do not return anything other than a 'yes' or 'no'.
"""

    # Built once at class-definition time and shared by all instances.
    _GRADER_PROMPT = PromptTemplate(
        template=dedent(_GRADER_PROMPT_TEMPLATE),
        input_variables=["context", "question"],
    )

    def __init__(self, model: "str | None" = None) -> None:
        """Assemble the grader chain.

        Args:
            model: Name of the OpenAI chat model to use. When omitted (or
                empty) the ``OPENAI_MODEL`` environment variable is read, which
                is the original behaviour.

        Raises:
            KeyError: If ``model`` is not given and ``OPENAI_MODEL`` is unset.
        """
        model_name = model or os.environ["OPENAI_MODEL"]

        # A dedicated model wrapper for this chain, so the tool binding below
        # lives on its own instance. temperature=0 is passed for grading.
        llm = ChatOpenAI(temperature=0, model=model_name)

        # Describe the ``grade`` model as an OpenAI tool definition.
        grade_tool_oai = convert_to_openai_tool(grade)

        # Bind the tool and force its invocation. The name in ``tool_choice``
        # must stay in sync with the ``grade`` class name.
        llm_with_tool = llm.bind(
            tools=[grade_tool_oai],
            tool_choice={"type": "function", "function": {"name": "grade"}},
        )

        # Parses the model's tool call back into ``grade`` objects.
        parser_tool = PydanticToolsParser(tools=[grade])

        self._grader_chain = self._GRADER_PROMPT | llm_with_tool | parser_tool

    def run(self, question, context):
        """Grade ``context`` against ``question`` and return the chain output.

        Args:
            question: The user question, substituted for ``{question}``.
            context: The retrieved document, substituted for ``{context}``.

        Returns:
            Whatever the final parser step produces.
            NOTE(review): PydanticToolsParser presumably yields a list of
            ``grade`` instances rather than a single one - confirm against
            callers before indexing the result.
        """
        return self._grader_chain.invoke({"context": context, "question": question})