""" | |
In this example we train a semantic search model to search through Wikipedia | |
articles about programming articles & technologies. | |
We use the text paragraphs from the following Wikipedia articles: | |
Assembly language, C , C Sharp , C++, Go , Java , JavaScript, Keras, Laravel, MATLAB, Matplotlib, MongoDB, MySQL, Natural Language Toolkit, NumPy, pandas (software), Perl, PHP, PostgreSQL, Python , PyTorch, R , React, Rust , Scala , scikit-learn, SciPy, Swift , TensorFlow, Vue.js | |
In: | |
1_programming_query_generation.py - We generate queries for all paragraphs from these articles | |
2_programming_train_bi-encoder.py - We train a SentenceTransformer bi-encoder with these generated queries. This results in a model we can then use for sematic search (for the given Wikipedia articles). | |
3_programming_semantic_search.py - Shows how the trained model can be used for semantic search | |
""" | |
import gzip
import json
import os

import torch
import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration

from sentence_transformers import util

paragraphs = set()

# We use the Wikipedia articles of certain programming languages
corpus_filepath = 'wiki-programmming-20210101.jsonl.gz'
if not os.path.exists(corpus_filepath):
    util.http_get('https://sbert.net/datasets/wiki-programmming-20210101.jsonl.gz', corpus_filepath)

with gzip.open(corpus_filepath, 'rt') as fIn:
    for line in fIn:
        data = json.loads(line.strip())
        for p in data['paragraphs']:
            if len(p) > 100:    # Only keep paragraphs with more than 100 characters
                paragraphs.add(p)

paragraphs = list(paragraphs)
print("Paragraphs:", len(paragraphs))

# Now we load the model that is able to generate queries given a paragraph.
# This model was trained on the MS MARCO dataset, a dataset with 500k
# queries from Bing and their respective relevant passages
tokenizer = T5Tokenizer.from_pretrained('BeIR/query-gen-msmarco-t5-large-v1')
model = T5ForConditionalGeneration.from_pretrained('BeIR/query-gen-msmarco-t5-large-v1')
model.eval()

# Select the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

# Parameters for generation
batch_size = 8              # Paragraphs per batch
num_queries = 5             # Number of queries to generate for every paragraph
max_length_paragraph = 300  # Max number of input tokens for a paragraph
max_length_query = 64       # Max number of tokens for a generated query

# Now for every paragraph in our corpus, we generate the queries
with open('generated_queries.tsv', 'w') as fOut:
    for start_idx in tqdm.trange(0, len(paragraphs), batch_size):
        sub_paragraphs = paragraphs[start_idx:start_idx + batch_size]
        # Tokenize the batch, padding to the longest paragraph and truncating to max_length_paragraph tokens
        inputs = tokenizer(sub_paragraphs, max_length=max_length_paragraph, truncation=True, padding=True, return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length_query,
                do_sample=True,
                top_p=0.95,
                num_return_sequences=num_queries)
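        # generate() with num_return_sequences=num_queries samples num_queries
        # queries per input, so outputs holds len(sub_paragraphs) * num_queries
        # sequences, with the queries for one paragraph stored consecutively.
        # Nucleus sampling (top_p=0.95) samples only from the smallest set of
        # tokens whose cumulative probability exceeds 0.95, which yields varied
        # but plausible queries.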
        for idx, out in enumerate(outputs):
            query = tokenizer.decode(out, skip_special_tokens=True)
            para = sub_paragraphs[idx // num_queries]
            fOut.write("{}\t{}\n".format(query.replace("\t", " ").strip(), para.replace("\t", " ").strip()))
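
# A minimal sketch of how the generated pairs can be loaded in the next step
# (2_programming_train_bi-encoder.py); InputExample is the standard
# sentence-transformers training container, though the exact training setup
# there may differ:
#
#   from sentence_transformers import InputExample
#   train_examples = []
#   with open('generated_queries.tsv') as fIn:
#       for line in fIn:
#           query, paragraph = line.strip().split('\t', maxsplit=1)
#           train_examples.append(InputExample(texts=[query, paragraph]))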