import gzip
import json
import pandas as pd
import numpy as np
import jax.numpy as jnp
import tqdm
from sentence_transformers import util
from typing import List, Union
import torch
from backend.utils import load_model, filter_questions, load_embeddings


def cos_sim(a, b):
    # Cosine similarity between each row of `a` and each row of `b`.
    # Each vector is normalized along its own row so the scores stay correct
    # even when `b` holds many embeddings.
    a_norm = a / jnp.linalg.norm(a, axis=-1, keepdims=True)
    b_norm = b / jnp.linalg.norm(b, axis=-1, keepdims=True)
    return jnp.matmul(a_norm, jnp.transpose(b_norm))
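
# Minimal usage sketch (not part of the app): cos_sim takes row-vector
# matrices, e.g. a (1, d) query against an (n, d) corpus, and returns a
# (1, n) matrix of cosine similarities.
#
#   q = jnp.ones((1, 4))   # query embedding, norm 2
#   c = jnp.eye(4)         # four unit-norm corpus embeddings
#   cos_sim(q, c)          # -> [[0.5, 0.5, 0.5, 0.5]]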


# Compute the similarity between an anchor sentence and a list of input sentences.
def text_similarity(anchor: str, inputs: List[str], model_name: str, model_dict: dict):
    print(model_name)
    model = load_model(model_name, model_dict)
    # Create embeddings. A single model exposes .encode directly; otherwise the
    # entry is a pair of models, the first encoding the anchor and the second
    # encoding the inputs.
    if hasattr(model, 'encode'):
        anchor_emb = model.encode(anchor)[None, :]
        inputs_emb = model.encode(inputs)
    else:
        assert len(model) == 2
        anchor_emb = model[0].encode(anchor)[None, :]
        inputs_emb = model[1].encode(inputs)
    # Compute similarity scores
    similarity = list(jnp.squeeze(cos_sim(anchor_emb, inputs_emb)))
    # Return the results as a pandas DataFrame
    d = {'inputs': inputs,
         'score': [round(float(s), 3) for s in similarity]}
    df = pd.DataFrame(d, columns=['inputs', 'score'])
    return df
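
# Example call (a sketch, assuming model_dict maps model names to models
# loadable via backend.utils.load_model):
#
#   df = text_similarity("How do I sort a list?",
#                        ["Sorting a list in Python", "Reading a CSV file"],
#                        "mpnet_qa", model_dict)
#   print(df)   # one row per input sentence with its similarity score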


# Semantic search: retrieve the top-n Stack Overflow questions for a query.
def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
    print(model_name)
    assert model_name == "mpnet_qa"
    model = load_model(model_name, model_dict)
    # Encode the query
    query_emb = model.encode(anchor, convert_to_tensor=True)[None, :]
    print("loading embeddings")
    corpus_emb = load_embeddings()
    # Retrieve the top-k hits by dot-product score
    hits = util.semantic_search(query_emb, corpus_emb, score_function=util.dot_score, top_k=n_answers)[0]
    filtered_posts = filter_questions("python")
    print(f"{len(filtered_posts)} posts found with tag: python")
    # Map each hit back to its post title, score and Stack Overflow URL
    hits_titles = []
    hits_scores = []
    urls = []
    for hit in hits:
        post = filtered_posts[hit['corpus_id']]
        hits_titles.append(post['title'])
        hits_scores.append("{:.3f}".format(hit['score']))
        urls.append(f"https://stackoverflow.com/q/{post['id']}")
    return hits_titles, hits_scores, urls
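
# Example call (a sketch, assuming the precomputed corpus embeddings and the
# tag-filtered posts are available via backend.utils.load_embeddings and
# filter_questions):
#
#   titles, scores, urls = text_search("How do I merge two dictionaries?",
#                                      n_answers=5,
#                                      model_name="mpnet_qa",
#                                      model_dict=model_dict)
#   for t, s, u in zip(titles, scores, urls):
#       print(s, t, u)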