"""
This example demonstrates the setup for question-answer retrieval.

You can enter a query or a question. The script then uses semantic search
to find relevant passages in Simple English Wikipedia (which is much smaller than full
English Wikipedia and therefore fits more easily in RAM).

As the model, we use nq-distilbert-base-v1.

It was trained on the Natural Questions dataset, a dataset of real questions from Google Search
together with annotated Wikipedia passages that provide the answer. For the passages, we encode the
Wikipedia article title together with the individual text passages.

Google Colab Example: https://colab.research.google.com/drive/11GunvCqJuebfeTlgbJWkIMT0xJH6PWF1?usp=sharing
"""
import gzip
import json
import os
import time

import torch
from sentence_transformers import SentenceTransformer, util

# We use the bi-encoder to encode all passages, so that we can use them with semantic search
model_name = 'nq-distilbert-base-v1'
bi_encoder = SentenceTransformer(model_name)
top_k = 5  # Number of passages we want to retrieve with the bi-encoder

# As dataset, we use Simple English Wikipedia. Compared to the full English Wikipedia, it has only
# about 170k articles. We split these articles into paragraphs and encode them with the bi-encoder.

wikipedia_filepath = 'data/simplewiki-2020-11-01.jsonl.gz'

if not os.path.exists(wikipedia_filepath):
    util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)

passages = []
with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        data = json.loads(line.strip())
        for paragraph in data['paragraphs']:
            # We encode the passages as [title, text]
            passages.append([data['title'], paragraph])

# If you like, you can also limit the number of passages you want to use
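# For example (arbitrary cutoff, uncomment if you want a smaller corpus):
# passages = passages[:50000]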
print("Passages:", len(passages))

# To speed things up, pre-computed embeddings are downloaded.
# The provided file contains the passages encoded with the model 'nq-distilbert-base-v1'
if model_name == 'nq-distilbert-base-v1':
    embeddings_filepath = 'simplewiki-2020-11-01-nq-distilbert-base-v1.pt'
    if not os.path.exists(embeddings_filepath):
        util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01-nq-distilbert-base-v1.pt', embeddings_filepath)

    corpus_embeddings = torch.load(embeddings_filepath, map_location=torch.device('cpu'))
    corpus_embeddings = corpus_embeddings.float()  # Ensure the embeddings are float32
    if torch.cuda.is_available():
        corpus_embeddings = corpus_embeddings.to('cuda')
else:  # Here, we compute the corpus_embeddings from scratch (which can take a while depending on the GPU)
    corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
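    # Optionally cache the freshly computed embeddings so later runs can skip re-encoding.
    # The filename below is an assumption for illustration, not part of the original script.
    torch.save(corpus_embeddings, 'simplewiki-2020-11-01-{}.pt'.format(model_name))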


def search(query):
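    """Encode the query with the bi-encoder and print the top_k most similar Wikipedia passages."""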

    # Encode the query using the bi-encoder and find potentially relevant passages
    start_time = time.time()
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first query

    end_time = time.time()

    # Output of top-k hits
    print("Input question:", query)
    print("Results (after {:.3f} seconds):".format(end_time - start_time))
    for hit in hits:
        print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']]))

    print("\n\n========\n")

def main():
    query = input("Please enter a question: ")
    search(query)

if __name__ == "__main__":
    main()