Spaces:
Runtime error
Adds a global document store and a top_k parameter
Browse files
- core/pipelines.py +20 -5
- interface/components.py +2 -2
- interface/pages.py +3 -1
core/pipelines.py
CHANGED
|
@@ -14,8 +14,14 @@ import os
|
|
| 14 |
data_path = "data/"
|
| 15 |
os.makedirs(data_path, exist_ok=True)
|
| 16 |
|
|
|
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
"""
|
| 20 |
**Keyword Search Pipeline**
|
| 21 |
|
|
@@ -26,8 +32,10 @@ def keyword_search(index="documents", split_word_length=100, audio_output=False)
|
|
| 26 |
- Documents that have more lexical overlap with the query are more likely to be relevant
|
| 27 |
- Words that occur in fewer documents are more significant than words that occur in many documents
|
| 28 |
"""
|
| 29 |
-
document_store
|
| 30 |
-
|
|
|
|
|
|
|
| 31 |
processor = PreProcessor(
|
| 32 |
clean_empty_lines=True,
|
| 33 |
clean_whitespace=True,
|
|
@@ -65,6 +73,7 @@ def dense_passage_retrieval(
|
|
| 65 |
split_word_length=100,
|
| 66 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
| 67 |
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
|
|
|
| 68 |
audio_output=False,
|
| 69 |
):
|
| 70 |
"""
|
|
@@ -76,11 +85,14 @@ def dense_passage_retrieval(
|
|
| 76 |
- One BERT base model to encode queries
|
| 77 |
- Ranking of documents done by dot product similarity between query and document embeddings
|
| 78 |
"""
|
| 79 |
-
document_store
|
|
|
|
|
|
|
| 80 |
dpr_retriever = DensePassageRetriever(
|
| 81 |
document_store=document_store,
|
| 82 |
query_embedding_model=query_embedding_model,
|
| 83 |
passage_embedding_model=passage_embedding_model,
|
|
|
|
| 84 |
)
|
| 85 |
processor = PreProcessor(
|
| 86 |
clean_empty_lines=True,
|
|
@@ -121,6 +133,7 @@ def dense_passage_retrieval_ranker(
|
|
| 121 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
| 122 |
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
| 123 |
ranker_model="cross-encoder/ms-marco-MiniLM-L-12-v2",
|
|
|
|
| 124 |
audio_output=False,
|
| 125 |
):
|
| 126 |
"""
|
|
@@ -137,8 +150,10 @@ def dense_passage_retrieval_ranker(
|
|
| 137 |
split_word_length=split_word_length,
|
| 138 |
query_embedding_model=query_embedding_model,
|
| 139 |
passage_embedding_model=passage_embedding_model,
|
|
|
|
|
|
|
| 140 |
)
|
| 141 |
-
ranker = SentenceTransformersRanker(model_name_or_path=ranker_model)
|
| 142 |
|
| 143 |
search_pipeline.add_node(ranker, name="Ranker", inputs=["DPRRetriever"])
|
| 144 |
|
|
|
|
| 14 |
data_path = "data/"
|
| 15 |
os.makedirs(data_path, exist_ok=True)
|
| 16 |
|
| 17 |
+
index = "documents"
|
| 18 |
|
| 19 |
+
document_store = InMemoryDocumentStore(index=index)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def keyword_search(
|
| 23 |
+
index="documents", split_word_length=100, top_k=10, audio_output=False
|
| 24 |
+
):
|
| 25 |
"""
|
| 26 |
**Keyword Search Pipeline**
|
| 27 |
|
|
|
|
| 32 |
- Documents that have more lexical overlap with the query are more likely to be relevant
|
| 33 |
- Words that occur in fewer documents are more significant than words that occur in many documents
|
| 34 |
"""
|
| 35 |
+
global document_store
|
| 36 |
+
if index != document_store.index:
|
| 37 |
+
document_store = InMemoryDocumentStore(index=index)
|
| 38 |
+
keyword_retriever = TfidfRetriever(document_store=(document_store), top_k=top_k)
|
| 39 |
processor = PreProcessor(
|
| 40 |
clean_empty_lines=True,
|
| 41 |
clean_whitespace=True,
|
|
|
|
| 73 |
split_word_length=100,
|
| 74 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
| 75 |
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
| 76 |
+
top_k=10,
|
| 77 |
audio_output=False,
|
| 78 |
):
|
| 79 |
"""
|
|
|
|
| 85 |
- One BERT base model to encode queries
|
| 86 |
- Ranking of documents done by dot product similarity between query and document embeddings
|
| 87 |
"""
|
| 88 |
+
global document_store
|
| 89 |
+
if index != document_store.index:
|
| 90 |
+
document_store = InMemoryDocumentStore(index=index)
|
| 91 |
dpr_retriever = DensePassageRetriever(
|
| 92 |
document_store=document_store,
|
| 93 |
query_embedding_model=query_embedding_model,
|
| 94 |
passage_embedding_model=passage_embedding_model,
|
| 95 |
+
top_k=top_k,
|
| 96 |
)
|
| 97 |
processor = PreProcessor(
|
| 98 |
clean_empty_lines=True,
|
|
|
|
| 133 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
| 134 |
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
| 135 |
ranker_model="cross-encoder/ms-marco-MiniLM-L-12-v2",
|
| 136 |
+
top_k=10,
|
| 137 |
audio_output=False,
|
| 138 |
):
|
| 139 |
"""
|
|
|
|
| 150 |
split_word_length=split_word_length,
|
| 151 |
query_embedding_model=query_embedding_model,
|
| 152 |
passage_embedding_model=passage_embedding_model,
|
| 153 |
+
# Keep the retrieval top_k very high to maximize recall; the ranker's own top_k then enforces precision.
|
| 154 |
+
top_k=10000000,
|
| 155 |
)
|
| 156 |
+
ranker = SentenceTransformersRanker(model_name_or_path=ranker_model, top_k=top_k)
|
| 157 |
|
| 158 |
search_pipeline.add_node(ranker, name="Ranker", inputs=["DPRRetriever"])
|
| 159 |
|
interface/components.py
CHANGED
|
@@ -27,9 +27,9 @@ def component_select_pipeline(container):
|
|
| 27 |
elif isinstance(value, bool):
|
| 28 |
value = st.checkbox(parameter, value)
|
| 29 |
elif isinstance(value, int):
|
| 30 |
-
value = int(st.number_input(parameter, value))
|
| 31 |
elif isinstance(value, float):
|
| 32 |
-
value = float(st.number_input(parameter, value))
|
| 33 |
pipeline_func_parameters[index_pipe][parameter] = value
|
| 34 |
if (
|
| 35 |
st.session_state["pipeline"] is None
|
|
|
|
| 27 |
elif isinstance(value, bool):
|
| 28 |
value = st.checkbox(parameter, value)
|
| 29 |
elif isinstance(value, int):
|
| 30 |
+
value = int(st.number_input(parameter, value=value))
|
| 31 |
elif isinstance(value, float):
|
| 32 |
+
value = float(st.number_input(parameter, value=value))
|
| 33 |
pipeline_func_parameters[index_pipe][parameter] = value
|
| 34 |
if (
|
| 35 |
st.session_state["pipeline"] is None
|
interface/pages.py
CHANGED
|
@@ -88,7 +88,9 @@ def page_index(container):
|
|
| 88 |
index_results = None
|
| 89 |
if st.button("Index"):
|
| 90 |
index_results = index(
|
| 91 |
-
corpus,
|
|
|
|
|
|
|
| 92 |
)
|
| 93 |
st.session_state["doc_id"] = doc_id
|
| 94 |
if index_results:
|
|
|
|
| 88 |
index_results = None
|
| 89 |
if st.button("Index"):
|
| 90 |
index_results = index(
|
| 91 |
+
documents=corpus,
|
| 92 |
+
pipeline=st.session_state["pipeline"]["index_pipeline"],
|
| 93 |
+
clear_index=clear_index,
|
| 94 |
)
|
| 95 |
st.session_state["doc_id"] = doc_id
|
| 96 |
if index_results:
|