Spaces:
Sleeping
Sleeping
File size: 4,932 Bytes
465a7e3 c1367c2 465a7e3 95b0fa1 465a7e3 95b0fa1 465a7e3 95b0fa1 465a7e3 95b0fa1 465a7e3 c1367c2 465a7e3 c1367c2 465a7e3 c1367c2 465a7e3 c1367c2 465a7e3 95b0fa1 465a7e3 95b0fa1 465a7e3 95b0fa1 465a7e3 95b0fa1 465a7e3 95b0fa1 465a7e3 95b0fa1 465a7e3 95b0fa1 465a7e3 95b0fa1 465a7e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from dotenv import load_dotenv
from typing_extensions import List, TypedDict
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_qdrant import QdrantVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import START, StateGraph
from langchain.prompts import ChatPromptTemplate
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from qdrant_client.http.models import Distance, VectorParams
# NLTK data packages that DirectoryLoader's underlying dependencies need at
# load time; fetched once when this module is imported.
import nltk

for _nltk_resource in ("punkt_tab", "averaged_perceptron_tagger_eng"):
    nltk.download(_nltk_resource)
# --- Text chunking -----------------------------------------------------------
# Passed to RecursiveCharacterTextSplitter: maximum chunk length, and how much
# of that length consecutive chunks share (half of it).
CHUNK_SIZE = 1000
CHUNK_OVERLAP = CHUNK_SIZE // 2

# --- Prompt ------------------------------------------------------------------
# Answer-generation template; the {question} and {context} placeholders are
# filled in via ChatPromptTemplate.from_template(...).format_messages(...).
RAG_PROMPT = (
    "You are a helpful assistant who helps Shopify merchants automate their businesses.\n"
    "Your goal is to provide a helpful response to the merchant's question in straight forward, non technical language.\n"
    "Try to be brief and to the point, but explain technical jargon.\n"
    "You must only use the provided context, and cannot use your own knowledge.\n"
    "### Question\n"
    "{question}\n"
    "### Context\n"
    "{context}\n"
)
class RagGraph:
    """Retrieval-augmented generation over a Qdrant collection of scraped docs.

    On construction the Qdrant collection is created and populated from the
    ``*.txt`` files in ``data/scraped/clean`` if it does not exist yet. A
    top-5 retriever is then built and a two-node LangGraph graph
    (``retrieve`` -> ``stream``) is compiled for :meth:`stream`.

    Args:
        qdrant_client: Qdrant client used to check for, create, populate and
            query the collection.
        use_finetuned_embeddings: If True, embed with the fine-tuned
            HuggingFace model and use the ``rag_collection_finetuned``
            collection. Otherwise use OpenAI ``text-embedding-3-small`` and
            the ``rag_collection`` collection.
    """

    def __init__(self, qdrant_client, use_finetuned_embeddings=False):
        self.llm = ChatOpenAI(model="gpt-4-turbo-preview", streaming=True)
        if use_finetuned_embeddings:
            self.collection_name = "rag_collection_finetuned"
            self.embeddings_model = HuggingFaceEmbeddings(
                model_name="thomfoolery/AIE5-MidTerm-finetuned-embeddings"
            )
            # NOTE(review): assumes the fine-tuned model emits 1024-dim vectors — confirm.
            dimension_size = 1024
        else:
            self.collection_name = "rag_collection"
            self.embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small")
            dimension_size = 1536
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
        )
        self.rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
        self.qdrant_client = qdrant_client

        does_collection_exist = self.qdrant_client.collection_exists(
            collection_name=self.collection_name
        )
        print(f"Collection {self.collection_name} exists: {does_collection_exist}")
        if not does_collection_exist:
            self.qdrant_client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=dimension_size, distance=Distance.COSINE),
            )

        self.vector_store = QdrantVectorStore(
            client=self.qdrant_client,
            collection_name=self.collection_name,
            embedding=self.embeddings_model,
        )

        if not does_collection_exist:
            try:
                loader = DirectoryLoader("data/scraped/clean", glob="*.txt")
                documents = self.text_splitter.split_documents(loader.load())
                self.vector_store.add_documents(documents=documents)
            except Exception:
                # Don't leave an empty collection behind: the existence check
                # above would otherwise skip ingestion on every later start.
                self.qdrant_client.delete_collection(collection_name=self.collection_name)
                raise

        self.vector_db_retriever = self.vector_store.as_retriever(search_kwargs={"k": 5})
        self.graph = None
        self.create()

    def _build_messages(self, question, documents):
        """Format the RAG prompt for ``question`` using the documents' text as context.

        Document contents are joined with blank lines. Both the graph node and
        :meth:`run` use this so the two paths send identical prompts.
        """
        context = "\n\n".join(doc.page_content for doc in documents)
        return self.rag_prompt.format_messages(question=question, context=context)

    def create(self):
        """Compile the ``retrieve`` -> ``stream`` graph into ``self.graph``."""

        class State(TypedDict):
            """State for the conversation."""

            question: str
            context: List[Document]
            response: str

        def retrieve(state):
            """Fetch the chunks most relevant to the question."""
            return {"context": self.vector_db_retriever.invoke(state["question"])}

        # The function name becomes the node name; it is kept as "stream" so the
        # graph's node names are unchanged.
        async def stream(state, config):
            """Answer the question from the retrieved context.

            ``config`` is forwarded to the LLM so LangGraph's callbacks reach it.
            That is what lets ``stream_mode="messages"`` surface tokens; explicit
            forwarding is required on Python < 3.11.
            """
            messages = self._build_messages(state["question"], state["context"])
            response = await self.llm.ainvoke(messages, config)
            # Write to a declared State key; undeclared keys are not stored.
            return {"response": response.content}

        graph_builder = StateGraph(State).add_sequence([retrieve, stream])
        graph_builder.add_edge(START, "retrieve")
        self.graph = graph_builder.compile()

    def run(self, question):
        """Answer ``question`` without streaming.

        Returns:
            dict with ``"response"`` (the model's answer text) and ``"context"``
            (the retrieved documents).
        """
        chunks = self.vector_db_retriever.invoke(question)
        response = self.llm.invoke(self._build_messages(question, chunks))
        return {
            "response": response.content,
            "context": chunks,
        }

    async def stream(self, question, msg):
        """Stream the answer to ``question`` into ``msg`` token by token, then send it.

        NOTE(review): ``msg`` presumably is a Chainlit ``Message``. It must
        provide awaitable ``stream_token(str)`` and ``send()`` — confirm against
        the caller.
        """
        inputs = {"question": question, "context": []}
        # With a list stream_mode, each event is a (mode, payload) pair; for
        # "messages" the payload is (message_chunk, metadata).
        async for _mode, (message_chunk, _metadata) in self.graph.astream(
            inputs, stream_mode=["messages"]
        ):
            if message_chunk.content:
                await msg.stream_token(message_chunk.content)
        await msg.send()
# Run RAG with CLI (no streaming)
def main():
    """Answer a sample question from the command line and print the result."""
    # Local import: only this CLI entry point needs to build its own client.
    from qdrant_client import QdrantClient

    load_dotenv()
    # RagGraph requires a Qdrant client. An in-memory instance keeps the CLI
    # self-contained; the collection is re-indexed from disk on every run.
    rag_graph = RagGraph(QdrantClient(":memory:"))
    # RagGraph.__init__ already compiles the graph via create(), so there is
    # no separate build step here.
    response = rag_graph.run("What is Shopify Flow?")
    print(response["response"])


if __name__ == "__main__":
    main()
|