"""Embedding-based text similarity scoring and StackOverflow question search."""
import gzip
import json
from typing import List, Tuple, Union

import jax.numpy as jnp
import numpy as np
import pandas as pd
import torch
import tqdm
from sentence_transformers import util

from backend.utils import load_model, filter_questions, load_embeddings


def cos_sim(a, b):
    """Return the pairwise cosine similarity between the rows of ``a`` and ``b``.

    Every vector along the last axis of ``a`` and ``b`` is scaled to unit
    length before taking the dot product. For ``a`` of shape ``(m, d)`` and
    ``b`` of shape ``(n, d)`` the result has shape ``(m, n)``, and entry
    ``[i, j]`` is the cosine similarity of ``a[i]`` and ``b[j]``. For 1-D
    inputs of equal length the result is a scalar.

    Rows with (near-)zero norm are divided by ``1e-12`` instead of 0, so they
    score 0 rather than NaN.
    """
    # Normalise per row. Dividing by the norm of the whole ``b`` matrix only
    # yields a cosine similarity when ``b`` holds a single vector.
    a_unit = a / jnp.maximum(jnp.linalg.norm(a, axis=-1, keepdims=True), 1e-12)
    b_unit = b / jnp.maximum(jnp.linalg.norm(b, axis=-1, keepdims=True), 1e-12)
    return jnp.matmul(a_unit, jnp.transpose(b_unit))


# We get similarity between embeddings.
def text_similarity(
    anchor: str, inputs: List[str], model_name: str, model_dict: dict
) -> pd.DataFrame:
    """Score each text in ``inputs`` by cosine similarity to ``anchor``.

    Args:
        anchor: Reference text.
        inputs: Candidate texts compared against ``anchor``.
        model_name: Name passed to ``load_model`` to select the encoder.
        model_dict: Passed through unchanged to ``load_model``.

    Returns:
        A DataFrame with columns ``inputs`` and ``score``. There is one row per
        input, in input order. ``score`` is the cosine similarity rounded to
        3 decimals. An empty ``inputs`` list yields an empty DataFrame with
        the same columns.
    """
    print(model_name)
    if not inputs:
        return pd.DataFrame({'inputs': [], 'score': []}, columns=['inputs', 'score'])

    model = load_model(model_name, model_dict)

    # Creating embeddings. A model exposing ``encode`` embeds both sides.
    # Otherwise the code indexes ``model[0]`` (anchor encoder) and ``model[1]``
    # (inputs encoder), so it expects a 2-element container.
    if hasattr(model, 'encode'):
        anchor_emb = model.encode(anchor)[None, :]
        inputs_emb = model.encode(inputs)
    else:
        assert len(model) == 2
        anchor_emb = model[0].encode(anchor)[None, :]
        inputs_emb = model[1].encode(inputs)

    # Obtaining similarity. ``anchor_emb`` has a leading axis of 1, so the
    # result is (1, len(inputs)); flatten it explicitly. ``jnp.squeeze`` would
    # produce a non-iterable 0-d array when there is exactly one input.
    similarity = np.asarray(cos_sim(anchor_emb, inputs_emb)).reshape(-1)

    # Returning a Pandas' dataframe with plain Python floats as scores.
    d = {'inputs': inputs, 'score': [round(float(s), 3) for s in similarity]}
    return pd.DataFrame(d, columns=['inputs', 'score'])


# Search
def text_search(
    anchor: str, n_answers: int, model_name: str, model_dict: dict
) -> Tuple[List[str], List[str], List[str]]:
    """Find the ``n_answers`` python-tagged StackOverflow questions closest to ``anchor``.

    Args:
        anchor: Query text.
        n_answers: Number of hits to return (``top_k`` for the search).
        model_name: Must be ``"mpnet_qa"``; passed to ``load_model``.
        model_dict: Passed through unchanged to ``load_model``.

    Returns:
        Three parallel lists, ordered as returned by ``util.semantic_search``:
        question titles, dot-product scores formatted to 3 decimals (as
        strings), and ``https://stackoverflow.com/q/<id>`` URLs.

    Raises:
        AssertionError: if ``model_name`` is not ``"mpnet_qa"``.
    """
    # Proceeding with model
    print(model_name)
    assert model_name == "mpnet_qa"
    model = load_model(model_name, model_dict)

    # Creating embeddings; add a leading batch axis for semantic_search.
    query_emb = model.encode(anchor, convert_to_tensor=True)[None, :]
    print("loading embeddings")
    corpus_emb = load_embeddings()

    # Getting hits for the single query (index 0 of the per-query results).
    hits = util.semantic_search(
        query_emb, corpus_emb, score_function=util.dot_score, top_k=n_answers
    )[0]

    # NOTE(review): ``corpus_id`` is used as an index into this list, which
    # assumes load_embeddings() was built from filter_questions("python") in
    # the same order. Confirm against backend.utils.
    filtered_posts = filter_questions("python")
    print(f"{len(filtered_posts)} posts found with tag: python")

    hit_posts = [filtered_posts[hit['corpus_id']] for hit in hits]
    hits_titles = [post['title'] for post in hit_posts]
    hits_scores = [f"{hit['score']:.3f}" for hit in hits]
    urls = [f"https://stackoverflow.com/q/{post['id']}" for post in hit_posts]
    return hits_titles, hits_scores, urls