"""
This example demonstrates the setup for question-answer retrieval.

You can input a query or a question. The script then uses semantic search
to find relevant passages in Simple English Wikipedia (as it is smaller and
fits better in RAM).

As the model, we use: nq-distilbert-base-v1

It was trained on the Natural Questions dataset, a dataset with real questions
from Google Search together with annotated data from Wikipedia providing the
answer. For the passages, we encode the Wikipedia article title together with
the individual text passages.

Google Colab Example: https://colab.research.google.com/drive/11GunvCqJuebfeTlgbJWkIMT0xJH6PWF1?usp=sharing
"""
import gzip
import json
import os
import time

import torch
from sentence_transformers import SentenceTransformer, util

# Load the bi-encoder that embeds questions and passages into the same vector space
model_name = 'nq-distilbert-base-v1'
bi_encoder = SentenceTransformer(model_name)
top_k = 5  # Number of passages to retrieve with the bi-encoder

# Download the Simple English Wikipedia dump if it is not present yet
wikipedia_filepath = 'data/simplewiki-2020-11-01.jsonl.gz'
if not os.path.exists(wikipedia_filepath):
    util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)

# Read the corpus: each entry is stored as a [title, paragraph] pair, since the
# model was trained to encode the article title together with the passage text
passages = []
with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        data = json.loads(line.strip())
        for paragraph in data['paragraphs']:
            passages.append([data['title'], paragraph])

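# Sanity check (an addition, not part of the original script): show the first
# [title, paragraph] pair so the shape of the corpus entries is visible.
if passages:
    print("Example passage entry:", passages[0])
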
print("Passages:", len(passages)) |
|
|
|
|
|
|
|
# For the default model, pre-computed corpus embeddings can be downloaded.
# Otherwise, all passages are encoded from scratch, which can take a while.
if model_name == 'nq-distilbert-base-v1':
    embeddings_filepath = 'simplewiki-2020-11-01-nq-distilbert-base-v1.pt'
    if not os.path.exists(embeddings_filepath):
        util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01-nq-distilbert-base-v1.pt', embeddings_filepath)

    # torch.load with map_location already places the tensor on CPU; move it to
    # the GPU only if one is available
    corpus_embeddings = torch.load(embeddings_filepath, map_location=torch.device('cpu'))
    corpus_embeddings = corpus_embeddings.float()
    if torch.cuda.is_available():
        corpus_embeddings = corpus_embeddings.to('cuda')
else:
    corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
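
# Optional addition (not part of the original example): cache freshly computed
# embeddings so later runs can skip re-encoding. The target filename below is
# an assumption, not an official dataset file.
# if model_name != 'nq-distilbert-base-v1':
#     torch.save(corpus_embeddings, 'simplewiki-2020-11-01-{}.pt'.format(model_name))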


def search(query):
    # Encode the question and retrieve the top_k most similar passages
    start_time = time.time()
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first (and only) query
    end_time = time.time()

    # Output the retrieved passages together with their similarity scores
    print("Input question:", query)
    print("Results (after {:.3f} seconds):".format(end_time - start_time))
    for hit in hits:
        print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']]))

    print("\n\n========\n")


def main():
    query = input("Please enter a question: ")
    search(query)

if __name__ == "__main__":
    main()