"""
This script contains an example how to perform semantic search with ElasticSearch.
As dataset, we use the Quora Duplicate Questions dataset, which contains about 500k questions:
https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs
Questions are indexed to ElasticSearch together with their respective sentence
embeddings.
The script shows results from BM25 as well as from semantic search with
cosine similarity.
You need ElasticSearch (https://www.elastic.co/de/elasticsearch/) up and running. Further, you need the Python
ElasticSearch Client installed: https://elasticsearch-py.readthedocs.io/en/master/
As embeddings model, we use the SBERT model 'quora-distilbert-multilingual',
that it aligned for 100 languages. I.e., you can type in a question in various languages and it will
return the closest questions in the corpus (questions in the corpus are mainly in English).
"""
from sentence_transformers import SentenceTransformer, util
import os
from elasticsearch import Elasticsearch, helpers
import csv
import time
import tqdm.autonotebook
es = Elasticsearch()
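# The default constructor above assumes an unsecured Elasticsearch instance on
# localhost:9200. For a remote or secured cluster you would pass the connection
# details explicitly; host URL and credentials below are placeholders, e.g.:
# es = Elasticsearch("https://localhost:9200", http_auth=("elastic", "<password>"))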
model = SentenceTransformer('quora-distilbert-multilingual')
url = "http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv"
dataset_path = "quora_duplicate_questions.tsv"
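# Only the first 100k unique questions are indexed, to keep embedding
# computation and indexing time manageable.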
max_corpus_size = 100000
# Download dataset if needed
if not os.path.exists(dataset_path):
    print("Download dataset")
    util.http_get(url, dataset_path)
# Get all unique sentences from the file
all_questions = {}
with open(dataset_path, encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_MINIMAL)
    for row in reader:
        all_questions[row['qid1']] = row['question1']
        if len(all_questions) >= max_corpus_size:
            break

        all_questions[row['qid2']] = row['question2']
        if len(all_questions) >= max_corpus_size:
            break
qids = list(all_questions.keys())
questions = [all_questions[qid] for qid in qids]
# Index data, if the index does not exist
if not es.indices.exists(index="quora"):
    try:
        es_index = {
            "mappings": {
                "properties": {
                    "question": {
                        "type": "text"
                    },
                    "question_vector": {
                        "type": "dense_vector",
                        "dims": 768  # must match the output dimension of the SBERT model
                    }
                }
            }
        }

        es.indices.create(index='quora', body=es_index, ignore=[400])
        chunk_size = 500
        print("Index data (you can stop it by pressing Ctrl+C once):")
        with tqdm.tqdm(total=len(qids)) as pbar:
            for start_idx in range(0, len(qids), chunk_size):
                end_idx = start_idx + chunk_size
                embeddings = model.encode(questions[start_idx:end_idx], show_progress_bar=False)
                bulk_data = []
                for qid, question, embedding in zip(qids[start_idx:end_idx], questions[start_idx:end_idx], embeddings):
                    bulk_data.append({
                        "_index": 'quora',
                        "_id": qid,
                        "_source": {
                            "question": question,
                            # tolist() makes the numpy array JSON-serializable
                            # independent of the client version
                            "question_vector": embedding.tolist()
                        }
                    })
                helpers.bulk(es, bulk_data)
                pbar.update(chunk_size)
    except:  # bare except, so that Ctrl+C stops indexing but the script still reaches the search loop
        print("During indexing an exception occurred. Continuing\n\n")
# Interactive search queries
while True:
    inp_question = input("Please enter a question: ")

    encode_start_time = time.time()
    question_embedding = model.encode(inp_question)
    encode_end_time = time.time()

    # Lexical search with BM25
    bm25 = es.search(index="quora", body={"query": {"match": {"question": inp_question}}})

    # Semantic search: score every document by cosine similarity to the query embedding.
    # The +1.0 shifts the score into the positive range, as script_score must not be negative.
    sem_search = es.search(index="quora", body={
        "query": {
            "script_score": {
                "query": {
                    "match_all": {}
                },
                "script": {
                    "source": "cosineSimilarity(params.queryVector, doc['question_vector']) + 1.0",
                    "params": {
                        "queryVector": question_embedding.tolist()
                    }
                }
            }
        }
    })

    print("Input question:", inp_question)
    print("Computing the embedding took {:.3f} seconds, BM25 search took {:.3f} seconds, semantic search with ES took {:.3f} seconds".format(encode_end_time - encode_start_time, bm25['took'] / 1000, sem_search['took'] / 1000))

    print("BM25 results:")
    for hit in bm25['hits']['hits'][0:5]:
        print("\t{}".format(hit['_source']['question']))

    print("\nSemantic Search results:")
    for hit in sem_search['hits']['hits'][0:5]:
        print("\t{}".format(hit['_source']['question']))

    print("\n\n========\n")