# rag_graph.py
from dotenv import load_dotenv
from typing_extensions import List, TypedDict
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_qdrant import QdrantVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import START, StateGraph
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
# NLTK data required by DirectoryLoader's unstructured text parsers
import nltk
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger_eng')
# Chunk configuration
CHUNK_SIZE = 1000
CHUNK_OVERLAP = CHUNK_SIZE // 2
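# Illustration of the half-overlap choice: with CHUNK_SIZE = 1000 and
# CHUNK_OVERLAP = 500, a 2,000-character document splits into windows roughly
# like [0:1000], [500:1500], [1000:2000] (the splitter prefers paragraph and
# sentence boundaries, so offsets are approximate). Each chunk shares about
# half its text with its neighbour, so sentences that straddle a boundary
# still appear whole in at least one chunk.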
# RAG prompt template
RAG_PROMPT = """\
You are a helpful assistant who helps Shopify merchants automate their businesses.
Your goal is to answer the merchant's question in straightforward, non-technical language.
Be brief and to the point, but explain any technical jargon you use.
Answer using only the provided context; do not draw on your own knowledge.
### Question
{question}
### Context
{context}
"""
class RagGraph:
def __init__(self, qdrant_client, use_finetuned_embeddings=False):
self.llm = ChatOpenAI(model="gpt-4-turbo-preview", streaming=True)
        if use_finetuned_embeddings:
            self.collection_name = "rag_collection_finetuned"
            self.embeddings_model = HuggingFaceEmbeddings(model_name="thomfoolery/AIE5-MidTerm-finetuned-embeddings")
        else:
            self.collection_name = "rag_collection"
            self.embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small")
self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
self.rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
self.qdrant_client = qdrant_client
does_collection_exist = self.qdrant_client.collection_exists(collection_name=self.collection_name)
        # text-embedding-3-small emits 1536-dim vectors; the fine-tuned model emits 1024-dim vectors
        dimension_size = 1024 if use_finetuned_embeddings else 1536
print(f"Collection {self.collection_name} exists: {does_collection_exist}")
if not does_collection_exist:
qdrant_client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=dimension_size, distance=Distance.COSINE),
)
self.vector_store = QdrantVectorStore(
client=qdrant_client,
collection_name=self.collection_name,
embedding=self.embeddings_model,
)
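        # Sanity check (hypothetical, not part of the pipeline): the store can
        # be queried directly, e.g.
        #     self.vector_store.similarity_search("shopify flow", k=2)
        # which embeds the query with self.embeddings_model and runs a cosine
        # search against the Qdrant collection configured above.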
if not does_collection_exist:
loader = DirectoryLoader("data/scraped/clean", glob="*.txt")
documents = self.text_splitter.split_documents(loader.load())
self.vector_store.add_documents(documents=documents)
self.vector_db_retriever = self.vector_store.as_retriever(search_kwargs={"k": 5})
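        # k=5: each question pulls the five nearest chunks by cosine similarity;
        # raising k widens recall at the cost of a longer prompt.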
self.graph = None
self.create()
def create(self):
"""Create the RAG graph."""
class State(TypedDict):
"""State for the conversation."""
question: str
context: List[Document]
def retrieve(state):
question = state["question"]
context = self.vector_db_retriever.invoke(question)
return {"question": state["question"], "context": context}
async def stream(state):
"""LangGraph node that streams responses"""
question = state["question"]
context = "\n\n".join(doc.page_content for doc in state["context"])
messages = self.rag_prompt.format_messages(question=question, context=context)
async for chunk in self.llm.astream(messages):
yield {"content": chunk.content}
graph_builder = StateGraph(State).add_sequence([retrieve, stream])
graph_builder.add_edge(START, "retrieve")
self.graph = graph_builder.compile()
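        # Resulting topology: START -> retrieve -> stream. add_sequence wires
        # retrieve into stream, and the explicit edge from START makes retrieve
        # the entry point, so every run fetches context before generating.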
def run(self, question):
"""Invoke RAG response without streaming."""
chunks = self.vector_db_retriever.invoke(question)
context = "\n\n".join(doc.page_content for doc in chunks)
messages = self.rag_prompt.format_messages(question=question, context=context)
response = self.llm.invoke(messages)
return {
"response": response.content,
"context": chunks
}
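    # Example usage (hypothetical question shown):
    #     rag = RagGraph(QdrantClient(":memory:"))
    #     rag.run("What is Shopify Flow?")["response"]
    # The returned "context" carries the raw Documents, which is handy when
    # scoring retrieval separately from generation.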
async def stream(self, question, msg):
"""Stream RAG response."""
async for event in self.graph.astream({"question": question, "context": []}, stream_mode=["messages"]):
_event_name, (message_chunk, _metadata) = event
if message_chunk.content:
await msg.stream_token(message_chunk.content)
await msg.send()
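    # Note: `msg` is assumed to be a Chainlit-style Message object exposing
    # async `stream_token(...)` and `send()`; any object with those two methods
    # works, e.g. a minimal test double:
    #     class FakeMsg:
    #         async def stream_token(self, t): print(t, end="")
    #         async def send(self): print()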
# Run RAG with CLI (no streaming)
def main():
"""Test the RAG graph."""
load_dotenv()
    qdrant_client = QdrantClient(":memory:")  # in-memory Qdrant is enough for a local smoke test
    rag_graph = RagGraph(qdrant_client)  # the graph is built in __init__ via self.create()
response = rag_graph.run("What is Shopify Flow?")
print(response["response"])
if __name__ == "__main__":
main()