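"""RAG pipeline for answering Shopify merchant questions.

Builds a LangGraph retrieve -> stream sequence over a Qdrant vector store,
using either OpenAI or fine-tuned HuggingFace embeddings.
"""
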
from dotenv import load_dotenv

from typing_extensions import List, TypedDict

from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate

from langchain_qdrant import QdrantVectorStore
from langchain_huggingface import HuggingFaceEmbeddings

from langgraph.graph import START, StateGraph
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

# NLTK data required by DirectoryLoader's document parsers
import nltk
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger_eng')

# Chunk configuration
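# (50% overlap so context spanning a chunk boundary stays intact in at least one chunk)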
CHUNK_SIZE = 1000
CHUNK_OVERLAP = CHUNK_SIZE // 2

# RAG prompt template
RAG_PROMPT = """\
You are a helpful assistant who helps Shopify merchants automate their businesses.
Your goal is to provide a helpful response to the merchant's question in straightforward, non-technical language.
Be brief and to the point, but explain any technical jargon.
You must only use the provided context, and cannot use your own knowledge.

### Question
{question}

### Context
{context}
"""

class RagGraph:
  def __init__(self, qdrant_client, use_finetuned_embeddings=False):
    self.llm = ChatOpenAI(model="gpt-4-turbo-preview", streaming=True)

    # Each embedding model gets its own collection (vector dims differ: 1536 vs 1024).
    if use_finetuned_embeddings:
      self.collection_name = "rag_collection_finetuned"
      self.embeddings_model = HuggingFaceEmbeddings(model_name="thomfoolery/AIE5-MidTerm-finetuned-embeddings")
    else:
      self.collection_name = "rag_collection"
      self.embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small")
    
    self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    self.rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
    self.qdrant_client = qdrant_client

    does_collection_exist = self.qdrant_client.collection_exists(collection_name=self.collection_name)
    dimension_size = 1024 if use_finetuned_embeddings else 1536

    print(f"Collection {self.collection_name} exists: {does_collection_exist}")

    # Create the collection on first run; vector size must match the embedding model.
    if not does_collection_exist:
      self.qdrant_client.create_collection(
        collection_name=self.collection_name,
        vectors_config=VectorParams(size=dimension_size, distance=Distance.COSINE),
      )

    self.vector_store = QdrantVectorStore(
      client=qdrant_client,
      collection_name=self.collection_name,
      embedding=self.embeddings_model,
    )

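    # First run only: load, split, and embed the cleaned scraped documents.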
    if not does_collection_exist:
      loader = DirectoryLoader("data/scraped/clean", glob="*.txt")
      documents = self.text_splitter.split_documents(loader.load())
      self.vector_store.add_documents(documents=documents)
  
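    # Retrieve the 5 most similar chunks per question.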
    self.vector_db_retriever = self.vector_store.as_retriever(search_kwargs={"k": 5})    
    self.graph = None

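    # Build the graph once at construction time.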
    self.create()
    
  def create(self):
    """Create the RAG graph."""
    class State(TypedDict):
      """State for the conversation."""
      question: str
      context: List[Document]

    def retrieve(state):
      """LangGraph node that fetches the top-k context chunks for the question."""
      question = state["question"]
      context = self.vector_db_retriever.invoke(question)
      return {"question": question, "context": context}

    async def stream(state):
      """LangGraph node that streams the LLM response token by token."""
      question = state["question"]
      context = "\n\n".join(doc.page_content for doc in state["context"])
      messages = self.rag_prompt.format_messages(question=question, context=context)
      # Tokens emitted here also surface through the graph's "messages" stream mode.
      async for chunk in self.llm.astream(messages):
        yield {"content": chunk.content}

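    # Linear graph: START -> retrieve -> stream.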
    graph_builder = StateGraph(State).add_sequence([retrieve, stream])
    graph_builder.add_edge(START, "retrieve")
    self.graph = graph_builder.compile()

  def run(self, question):
    """Invoke RAG response without streaming."""
    chunks = self.vector_db_retriever.invoke(question)
    context = "\n\n".join(doc.page_content for doc in chunks)
    messages = self.rag_prompt.format_messages(question=question, context=context)
    response = self.llm.invoke(messages)
    return {
      "response": response.content,
      "context": chunks
    }
  
  async def stream(self, question, msg):
    """Stream the RAG response into `msg` (an object with Chainlit-style stream_token/send methods)."""
    async for event in self.graph.astream({"question": question, "context": []}, stream_mode=["messages"]):
      # With stream_mode given as a list, events arrive as (mode, payload) tuples;
      # for "messages" the payload is (message_chunk, metadata).
      _event_name, (message_chunk, _metadata) = event
      if message_chunk.content:
        await msg.stream_token(message_chunk.content)
    await msg.send()
  


# Run RAG with CLI (no streaming)
def main():
  """Test the RAG graph."""
  load_dotenv()
  # Assumption: the original call site didn't show client setup;
  # an in-memory Qdrant instance is used here for a self-contained test.
  rag_graph = RagGraph(QdrantClient(location=":memory:"))
  response = rag_graph.run("What is Shopify Flow?")
  print(response["response"])

if __name__ == "__main__":
  main()