krishanusinha20 committed on
Commit c172e70 · verified · 1 Parent(s): 16a7c80

Create app.py

Files changed (1)
  1. app.py +191 -0
app.py ADDED
@@ -0,0 +1,191 @@
+ import os
+ import asyncio
+ import chainlit as cl
+ from dotenv import load_dotenv
+ from operator import itemgetter
+ from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEndpointEmbeddings
+ from langchain_community.document_loaders import TextLoader
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain_core.prompts import PromptTemplate
+ from langchain.schema.runnable.config import RunnableConfig
+ from tqdm.asyncio import tqdm
+
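+ # NOTE: launch this app with `chainlit run app.py` once the endpoints below are configured.
+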
+ # GLOBAL SCOPE - ENTIRE APPLICATION HAS ACCESS TO VALUES SET IN THIS SCOPE #
+ # ---- ENV VARIABLES ---- #
+ """
+ Load our environment file (.env) if it is present.
+
+ NOTE: Make sure that .env is in your .gitignore file - it is by default, but please ensure it remains there.
+ """
+ load_dotenv()
+
+ """
+ We will load our environment variables here.
+ """
+ HF_LLM_ENDPOINT = os.environ["HF_LLM_ENDPOINT"]
+ HF_EMBED_ENDPOINT = os.environ["HF_EMBED_ENDPOINT"]
+ HF_TOKEN = os.environ["HF_TOKEN"]
+
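+ # Example .env contents (placeholder values - substitute your own endpoint URLs and token):
+ # HF_LLM_ENDPOINT=https://<your-llm-endpoint>.endpoints.huggingface.cloud
+ # HF_EMBED_ENDPOINT=https://<your-embedding-endpoint>.endpoints.huggingface.cloud
+ # HF_TOKEN=hf_...
+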
+ # ---- GLOBAL DECLARATIONS ---- #
+
+ # -- RETRIEVAL -- #
+ """
+ 1. Load Documents from Text File
+ 2. Split Documents into Chunks
+ 3. Load HuggingFace Embeddings (remember to use the URL we set above)
+ 4. Build a FAISS vectorstore from the chunks (indexed in async batches below)
+ """
+ document_loader = TextLoader("./data/paul_graham_essays.txt")
+ documents = document_loader.load()
+
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=30)
+ split_documents = text_splitter.split_documents(documents)
+
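+ # NOTE: chunk_size and chunk_overlap are measured in characters (the splitter's
+ # default length function), not tokens.
+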
+ hf_embeddings = HuggingFaceEndpointEmbeddings(
+     model=HF_EMBED_ENDPOINT,
+     task="feature-extraction",
+     huggingfacehub_api_token=HF_TOKEN,
+ )
+
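+ # Index the chunks in batches: the first batch builds the FAISS index and each
+ # later batch is appended to it, so embedding requests stay reasonably small.
+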
+ async def add_documents_async(vectorstore, documents):
+     await vectorstore.aadd_documents(documents)
+
+ async def process_batch(vectorstore, batch, is_first_batch, pbar):
+     if is_first_batch:
+         result = await FAISS.afrom_documents(batch, hf_embeddings)
+     else:
+         await add_documents_async(vectorstore, batch)
+         result = vectorstore
+     pbar.update(len(batch))
+     return result
+
+ async def build_retriever():
+     print("Indexing Files")
+
+     vectorstore = None
+     batch_size = 32
+
+     batches = [split_documents[i:i + batch_size] for i in range(0, len(split_documents), batch_size)]
+
+     async def process_all_batches():
+         nonlocal vectorstore
+         tasks = []
+         pbars = []
+
+         for i, batch in enumerate(batches):
+             pbar = tqdm(total=len(batch), desc=f"Batch {i+1}/{len(batches)}", position=i)
+             pbars.append(pbar)
+
+             if i == 0:
+                 vectorstore = await process_batch(None, batch, True, pbar)
+             else:
+                 tasks.append(process_batch(vectorstore, batch, False, pbar))
+
+         if tasks:
+             await asyncio.gather(*tasks)
+
+         for pbar in pbars:
+             pbar.close()
+
+     await process_all_batches()
+
+     hf_retriever = vectorstore.as_retriever()
+     print("\nIndexing complete. Vectorstore is ready for use.")
+     return hf_retriever
+
+ hf_retriever = asyncio.run(build_retriever())
+
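+ # NOTE: indexing runs once at import time, so the Chainlit app only starts
+ # serving requests after the vectorstore has been fully built.
+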
+ # -- AUGMENTED -- #
+ """
+ 1. Define a String Template
+ 2. Create a Prompt Template from the String Template
+ """
+ RAG_PROMPT_TEMPLATE = """\
+ <|start_header_id|>system<|end_header_id|>
+ You are a helpful assistant. You answer user questions based on provided context. If you can't answer the question with the provided context, say you don't know.<|eot_id|>
+
+ <|start_header_id|>user<|end_header_id|>
+ User Query:
+ {query}
+
+ Context:
+ {context}<|eot_id|>
+
+ <|start_header_id|>assistant<|end_header_id|>
+ """
+
+ rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
+
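+ # NOTE: the special tokens in the template above follow the Llama 3 chat format;
+ # if HF_LLM_ENDPOINT hosts a different model family, swap in that model's chat template.
+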
+ # -- GENERATION -- #
+ """
+ 1. Create a HuggingFaceEndpoint for the LLM
+ """
+ hf_llm = HuggingFaceEndpoint(
+     endpoint_url=HF_LLM_ENDPOINT,
+     max_new_tokens=512,
+     top_k=10,
+     top_p=0.95,
+     temperature=0.3,
+     repetition_penalty=1.15,
+     huggingfacehub_api_token=HF_TOKEN,
+ )
+
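+ # A relatively low temperature (0.3) keeps answers close to the retrieved
+ # context; raise it if you want more varied phrasing.
+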
+ @cl.author_rename
+ def rename(original_author: str):
+     """
+     This function can be used to rename the 'author' of a message.
+
+     In this case, we're overriding the 'Assistant' author to be 'Paul Graham Essay Bot'.
+     """
+     rename_dict = {
+         "Assistant": "Paul Graham Essay Bot"
+     }
+     return rename_dict.get(original_author, original_author)
+
+ @cl.on_chat_start
+ async def start_chat():
+     """
+     This function will be called at the start of every user session.
+
+     We will build our LCEL RAG chain here, and store it in the user session.
+
+     The user session is a dictionary that is unique to each user session, and is stored in the memory of the server.
+     """
+
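+     # The dict below runs in parallel: "query" is piped into the retriever to
+     # fetch context, and also passed through unchanged for the prompt.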
+     lcel_rag_chain = (
+         {"context": itemgetter("query") | hf_retriever, "query": itemgetter("query")}
+         | rag_prompt | hf_llm
+     )
+
+     cl.user_session.set("lcel_rag_chain", lcel_rag_chain)
+
+ @cl.on_message
+ async def main(message: cl.Message):
+     """
+     This function will be called every time a message is received from a session.
+
+     We will use the LCEL RAG chain to generate a response to the user query.
+
+     The LCEL RAG chain is stored in the user session, and is unique to each user session - this is why we can access it here.
+     """
+     lcel_rag_chain = cl.user_session.get("lcel_rag_chain")
+
+     msg = cl.Message(content="")
+
+     # Stream natively with astream so the event loop is never blocked between chunks.
+     async for chunk in lcel_rag_chain.astream(
+         {"query": message.content},
+         config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
+     ):
+         await msg.stream_token(chunk)
+
+     await msg.send()