# hello-sajal/src/chains/document_grader.py
"""Implements the document grader chain"""
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from textwrap import dedent
import os
from dotenv import load_dotenv
load_dotenv()

# Data model
class grade(BaseModel):
    """Binary score for relevance check."""

    binary_score: str = Field(description="Relevance score 'yes' or 'no'")

class DocumentGrader:
    """Implements the document grader chain"""

    _GRADER_PROMPT_TEMPLATE = """
        You are a grader assessing the relevance of a retrieved document to a user question. \n
        Retrieved document: \n\n {context} \n\n
        User Question: {question} \n
        When assessing relevance, consider whether the document can provide a complete answer to the question posed.
        A document is considered relevant only if it contains all the necessary information to fully answer the
        user's question without requiring additional context or assumptions.
        Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question.
        Do not return anything other than 'yes' or 'no'.
        """

    _GRADER_PROMPT = PromptTemplate(
        template=dedent(_GRADER_PROMPT_TEMPLATE),
        input_variables=["context", "question"],
    )

    def __init__(self):
        # Separate model wrapper instance for the bound grading tool
        llm = ChatOpenAI(temperature=0, model=os.environ["OPENAI_MODEL"])
        grade_tool_oai = convert_to_openai_tool(grade)

        # Bind the tool to the LLM and force its invocation on every call
        llm_with_tool = llm.bind(
            tools=[grade_tool_oai],
            tool_choice={"type": "function", "function": {"name": "grade"}},
        )

        # Parse the tool call back into a `grade` object
        parser_tool = PydanticToolsParser(tools=[grade])

        self._grader_chain = self._GRADER_PROMPT | llm_with_tool | parser_tool

    def run(self, question, context):
        """Returns the response from the document grader"""
        return self._grader_chain.invoke({"context": context, "question": question})