Faizan15 committed on
Commit 7136147 · 1 Parent(s): 70182f3

Create main.py

Files changed (1)
  1. main.py +208 -0
main.py ADDED
@@ -0,0 +1,208 @@
+from typing import List, Tuple
+from pathlib import Path
+import os
+import subprocess
+
+from dotenv import load_dotenv
+from haystack.preview import Pipeline
+from haystack.preview.dataclasses import GeneratedAnswer
+from haystack.preview.components.retrievers import MemoryBM25Retriever
+from haystack.preview.components.generators.openai.gpt import GPTGenerator
+from haystack.preview.components.builders.answer_builder import AnswerBuilder
+from haystack.preview.components.builders.prompt_builder import PromptBuilder
+from haystack.preview.components.preprocessors import (
+    DocumentCleaner,
+    TextDocumentSplitter,
+)
+from haystack.preview.components.writers import DocumentWriter
+from haystack.preview.components.file_converters import TextFileToDocument
+from haystack.preview.document_stores.memory import MemoryDocumentStore
+import streamlit as st
+
+# Load the environment variables; we'll need them for the OpenAI API key
+load_dotenv()
+
+# This is the list of documentation that we're going to fetch:
+# (name, repository URL, glob pattern for the documentation files)
+DOCUMENTATIONS = [
+    (
+        "DocArray",
+        "https://github.com/docarray/docarray",
+        "./docs/**/*.md",
+    ),
+    (
+        "Streamlit",
+        "https://github.com/streamlit/docs",
+        "./content/**/*.md",
+    ),
+    (
+        "Jinja",
+        "https://github.com/pallets/jinja",
+        "./docs/**/*.rst",
+    ),
+    (
+        "Pandas",
+        "https://github.com/pandas-dev/pandas",
+        "./doc/source/**/*.rst",
+    ),
+    (
+        "Elasticsearch",
+        "https://github.com/elastic/elasticsearch",
+        "./docs/**/*.asciidoc",
+    ),
+    (
+        "NumPy",
+        "https://github.com/numpy/numpy",
+        "./doc/**/*.rst",
+    ),
+]
+
+DOCS_PATH = Path(__file__).parent / "downloaded_docs"
+
+
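+# Streamlit caches the result, so the repositories are cloned and scanned
+# once and reused across reruns of the script.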
62
+ @st.cache_data(show_spinner=False)
63
+ def fetch(documentations: List[Tuple[str, str, str]]):
64
+ files = []
65
+ # Create the docs path if it doesn't exist
66
+ DOCS_PATH.mkdir(parents=True, exist_ok=True)
67
+
68
+ for name, url, pattern in documentations:
69
+ st.write(f"Fetching {name} repository")
70
+ repo = DOCS_PATH / name
71
+ # Attempt cloning only if it doesn't exist
72
+ if not repo.exists():
73
+ subprocess.run(["git", "clone", "--depth", "1", url, str(repo)], check=True)
74
+ res = subprocess.run(
75
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"],
76
+ check=True,
77
+ capture_output=True,
78
+ encoding="utf-8",
79
+ cwd=repo,
80
+ )
81
+ branch = res.stdout.strip()
82
+ for p in repo.glob(pattern):
83
+ data = {
84
+ "path": p,
85
+ "metadata": {
86
+ "url_source": f"{url}/tree/{branch}/{p.relative_to(repo)}",
87
+ "suffix": p.suffix,
88
+ },
89
+ }
90
+ files.append(data)
91
+
92
+ return files
93
+
94
+
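+# A single in-memory document store, shared between the indexing and query
+# pipelines and kept alive across reruns by st.cache_resource.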
+@st.cache_resource(show_spinner=False)
+def document_store():
+    # We're going to store the processed documents in here
+    return MemoryDocumentStore()
+
+
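+# Indexing: convert each file to a Document, clean it, split it into
+# shorter chunks, and write the chunks to the document store.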
+@st.cache_resource(show_spinner=False)
+def index_files(files):
+    # We create some components
+    text_converter = TextFileToDocument(progress_bar=False)
+    document_cleaner = DocumentCleaner()
+    document_splitter = TextDocumentSplitter()
+    document_writer = DocumentWriter(
+        document_store=document_store(), policy="overwrite"
+    )
+
+    # And our pipeline
+    indexing_pipeline = Pipeline()
+    indexing_pipeline.add_component("converter", text_converter)
+    indexing_pipeline.add_component("cleaner", document_cleaner)
+    indexing_pipeline.add_component("splitter", document_splitter)
+    indexing_pipeline.add_component("writer", document_writer)
+    indexing_pipeline.connect("converter", "cleaner")
+    indexing_pipeline.connect("cleaner", "splitter")
+    indexing_pipeline.connect("splitter", "writer")
+
+    # And now we save the documentation in our MemoryDocumentStore
+    paths = []
+    metadata = []
+    for f in files:
+        paths.append(f["path"])
+        metadata.append(f["metadata"])
+    indexing_pipeline.run(
+        {
+            "converter": {
+                "paths": paths,
+                "metadata": metadata,
+            }
+        }
+    )
+
+
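+# Query time: retrieve the top BM25 matches, build a prompt around them,
+# send it to the OpenAI generator, and wrap the reply in a GeneratedAnswer
+# that keeps track of its source documents.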
+def search(question: str) -> GeneratedAnswer:
+    retriever = MemoryBM25Retriever(document_store=document_store(), top_k=5)
+
+    # Join the document texts first, then flatten newlines; applying
+    # "replace" before "join" would act on the generator produced by
+    # "map" rather than on a string
+    template = (
+        "Take a deep breath and think, then answer given the context.\n"
+        "Context: {{ documents|map(attribute='text')|join(';')|replace('\n', ' ') }}\n"
+        "Question: {{ query }}\n"
+        "Answer:"
+    )
+    prompt_builder = PromptBuilder(template)
+
+    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
+    generator = GPTGenerator(api_key=OPENAI_API_KEY)
+    answer_builder = AnswerBuilder()
+
+    query_pipeline = Pipeline()
+
+    query_pipeline.add_component("docs_retriever", retriever)
+    query_pipeline.add_component("prompt_builder", prompt_builder)
+    query_pipeline.add_component("gpt35", generator)
+    query_pipeline.add_component("answer_builder", answer_builder)
+
+    query_pipeline.connect("docs_retriever.documents", "prompt_builder.documents")
+    query_pipeline.connect("prompt_builder.prompt", "gpt35.prompt")
+    query_pipeline.connect("docs_retriever.documents", "answer_builder.documents")
+    query_pipeline.connect("gpt35.replies", "answer_builder.replies")
+
+    res = query_pipeline.run(
+        {
+            "docs_retriever": {"query": question},
+            "prompt_builder": {"query": question},
+            "answer_builder": {"query": question},
+        }
+    )
+    return res["answer_builder"]["answers"][0]
+
+
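+# Download and index the documentation once, showing progress in a status
+# box that collapses after the first successful run.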
+with st.status(
+    "Downloading documentation files...",
+    expanded=st.session_state.get("expanded", True),
+) as status:
+    files = fetch(DOCUMENTATIONS)
+    status.update(label="Indexing documentation...")
+    index_files(files)
+    status.update(
+        label="Download and indexing complete!", state="complete", expanded=False
+    )
+    # Remember that setup already ran so the box stays collapsed on reruns
+    st.session_state["expanded"] = False
+
+
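+# The question-and-answer UI: take a question, run the query pipeline,
+# and render the answer together with links to its sources.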
+st.header("🔎 Documentation finder", divider="rainbow")
+
+st.caption(
+    f"Use this to search for answers about {', '.join([d[0] for d in DOCUMENTATIONS])}"
+)
+
+if question := st.text_input(
+    label="What do you need to know?", placeholder="What is a DataFrame?"
+):
+    with st.spinner("Waiting"):
+        answer = search(question)
+
+    # Celebrate, but only the first time a question is answered
+    if not st.session_state.get("run_once", False):
+        st.balloons()
+        st.session_state["run_once"] = True
+
+    st.markdown(answer.data)
+    with st.expander("See sources:"):
+        for document in answer.documents:
+            url_source = document.metadata.get("url_source", "")
+            st.write(url_source)
+            st.text(document.text)
+            st.divider()