from dotenv import load_dotenv
from typing_extensions import List, TypedDict

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_qdrant import QdrantVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import START, StateGraph
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

# Download NLTK data required by DirectoryLoader's document parsers
import nltk
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger_eng')

# Chunk configuration
CHUNK_SIZE = 1000
CHUNK_OVERLAP = CHUNK_SIZE // 2

# RAG prompt template
RAG_PROMPT = """\
You are a helpful assistant who helps Shopify merchants automate their businesses.
Your goal is to provide a helpful response to the merchant's question in straightforward, non-technical language.
Try to be brief and to the point, but explain technical jargon.
You must only use the provided context, and cannot use your own knowledge.
### Question
{question}
### Context
{context}
"""

class RagGraph:
    def __init__(self, qdrant_client, use_finetuned_embeddings=False):
        self.llm = ChatOpenAI(model="gpt-4-turbo-preview", streaming=True)
        self.collection_name = "rag_collection" if not use_finetuned_embeddings else "rag_collection_finetuned"
        self.embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small") \
            if not use_finetuned_embeddings else HuggingFaceEmbeddings(model_name="thomfoolery/AIE5-MidTerm-finetuned-embeddings")
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
        self.rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
        self.qdrant_client = qdrant_client

        does_collection_exist = self.qdrant_client.collection_exists(collection_name=self.collection_name)
        # 1536 dims for text-embedding-3-small, 1024 for the fine-tuned embedding model
        dimension_size = 1536 if not use_finetuned_embeddings else 1024
        print(f"Collection {self.collection_name} exists: {does_collection_exist}")
        if not does_collection_exist:
            qdrant_client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=dimension_size, distance=Distance.COSINE),
            )

        self.vector_store = QdrantVectorStore(
            client=qdrant_client,
            collection_name=self.collection_name,
            embedding=self.embeddings_model,
        )

        # Ingest documents only when the collection was just created
        if not does_collection_exist:
            loader = DirectoryLoader("data/scraped/clean", glob="*.txt")
            documents = self.text_splitter.split_documents(loader.load())
            self.vector_store.add_documents(documents=documents)

        self.vector_db_retriever = self.vector_store.as_retriever(search_kwargs={"k": 5})
        self.graph = None
        self.create()

    def create(self):
        """Create the RAG graph."""
        class State(TypedDict):
            """State for the conversation."""
            question: str
            context: List[Document]

        def retrieve(state):
            question = state["question"]
            context = self.vector_db_retriever.invoke(question)
            return {"question": state["question"], "context": context}

        async def stream(state):
            """LangGraph node that streams responses"""
            question = state["question"]
            context = "\n\n".join(doc.page_content for doc in state["context"])
            messages = self.rag_prompt.format_messages(question=question, context=context)
            async for chunk in self.llm.astream(messages):
                yield {"content": chunk.content}

        graph_builder = StateGraph(State).add_sequence([retrieve, stream])
        graph_builder.add_edge(START, "retrieve")
        self.graph = graph_builder.compile()

    def run(self, question):
        """Invoke RAG response without streaming."""
        chunks = self.vector_db_retriever.invoke(question)
        context = "\n\n".join(doc.page_content for doc in chunks)
        messages = self.rag_prompt.format_messages(question=question, context=context)
        response = self.llm.invoke(messages)
        return {
            "response": response.content,
            "context": chunks,
        }

    async def stream(self, question, msg):
        """Stream RAG response."""
        async for event in self.graph.astream({"question": question, "context": []}, stream_mode=["messages"]):
            _event_name, (message_chunk, _metadata) = event
            if message_chunk.content:
                await msg.stream_token(message_chunk.content)
        await msg.send()
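
# Usage sketch (assumption: the `msg` passed to `stream` is a Chainlit `cl.Message`,
# which exposes `stream_token`/`send`). A minimal Chainlit handler could wire up the
# streaming path like this, with `rag_graph` as a hypothetical module-level instance:
#
#   import chainlit as cl
#
#   @cl.on_message
#   async def on_message(message: cl.Message):
#       msg = cl.Message(content="")
#       await rag_graph.stream(message.content, msg)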

# Run RAG with CLI (no streaming)
def main():
    """Test the RAG graph."""
    load_dotenv()
    # Local smoke test against an in-memory Qdrant instance; the graph is built in __init__
    rag_graph = RagGraph(QdrantClient(":memory:"))
    # rag_graph.update_vector_store("data/scraped/clean", replace_documents=False)
    response = rag_graph.run("What is Shopify Flow?")
    print(response["response"])


if __name__ == "__main__":
    main()