Spaces:
Runtime error
Runtime error
import gzip | |
import json | |
import pandas as pd | |
import numpy as np | |
import jax.numpy as jnp | |
import tqdm | |
from sentence_transformers import util | |
from typing import List, Union | |
import torch | |
from backend.utils import load_model, filter_questions, load_embeddings | |
def cos_sim(a, b):
    """Return the cosine similarity between the vectors of ``a`` and ``b``.

    Every vector is scaled to unit length along its last axis before the
    dot product, so each entry of the result is the cosine of the angle
    between one vector of ``a`` and one vector of ``b``. Dividing by the
    norm of the whole array instead would apply a single global factor
    that grows with the number of rows, which is not a cosine similarity.

    Args:
        a: Array-like of shape ``(d,)`` or ``(m, d)``.
        b: Array-like of shape ``(d,)`` or ``(n, d)``.

    Returns:
        A ``jax.numpy`` array of shape ``(m, n)`` for two 2-D inputs, or a
        scalar array for two 1-D inputs. An all-zero vector produces
        ``nan`` entries (0 / 0), as no angle is defined for it.
    """
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    # keepdims=True keeps the norm broadcastable against its own vectors.
    a_unit = a / jnp.linalg.norm(a, axis=-1, keepdims=True)
    b_unit = b / jnp.linalg.norm(b, axis=-1, keepdims=True)
    return jnp.matmul(a_unit, jnp.transpose(b_unit))
# We get similarity between embeddings. | |
def text_similarity(anchor: str, inputs: List[str], model_name: str, model_dict: dict):
    """Score every text in ``inputs`` against ``anchor`` by embedding similarity.

    The model is obtained from ``load_model(model_name, model_dict)``. If the
    returned object has an ``encode`` attribute it embeds both the anchor
    and the inputs; otherwise it must be a 2-element sequence whose first
    element encodes the anchor and whose second element encodes the inputs.

    Args:
        anchor: Reference text that the inputs are compared with.
        inputs: Texts to score against the anchor.
        model_name: Name forwarded to ``load_model`` (also printed).
        model_dict: Mapping forwarded to ``load_model``.

    Returns:
        A pandas DataFrame with columns ``inputs`` and ``score``, one row
        per input text; scores are Python floats rounded to 3 decimals.

    Raises:
        AssertionError: If the loaded model has no ``encode`` attribute and
            is not of length 2.
    """
    print(model_name)
    model = load_model(model_name, model_dict)

    # Creating embeddings; the anchor gets a leading axis so it is (1, d).
    if hasattr(model, 'encode'):
        anchor_emb = model.encode(anchor)[None, :]
        inputs_emb = model.encode(inputs)
    else:
        assert len(model) == 2
        anchor_emb = model[0].encode(anchor)[None, :]
        inputs_emb = model[1].encode(inputs)

    # Obtaining similarity. Flatten instead of squeeze: squeezing a (1, 1)
    # result yields a 0-d array, which cannot be iterated when ``inputs``
    # holds a single text.
    raw_scores = np.asarray(cos_sim(anchor_emb, inputs_emb)).reshape(-1)
    # tolist() yields plain Python floats, so the frame holds ordinary
    # numbers rather than array scalars.
    scores = [round(score, 3) for score in raw_scores.tolist()]

    # Returning a pandas DataFrame
    return pd.DataFrame({'inputs': inputs, 'score': scores},
                        columns=['inputs', 'score'])
# Search | |
def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
    """Find the posts whose stored embeddings best match ``anchor``.

    Only the ``"mpnet_qa"`` model is accepted. The query is encoded with the
    loaded model, compared against the embeddings from ``load_embeddings()``
    using dot-product scoring, and the best ``n_answers`` hits are mapped
    onto the posts returned by ``filter_questions("python")``.

    NOTE(review): this assumes the rows of the loaded embeddings are in the
    same order as the filtered posts, since ``corpus_id`` is used directly
    as an index into that list - confirm against ``backend.utils``.

    Args:
        anchor: Query text.
        n_answers: Number of hits requested from the semantic search.
        model_name: Model name; must equal ``"mpnet_qa"``.
        model_dict: Mapping forwarded to ``load_model``.

    Returns:
        A tuple ``(titles, scores, urls)`` of three parallel lists: post
        titles, scores formatted with three decimals, and Stack Overflow
        question URLs.
    """
    # Proceeding with model
    print(model_name)
    assert model_name == "mpnet_qa"
    model = load_model(model_name, model_dict)

    # Creating embeddings (leading axis added so the query is a batch of one)
    query_emb = model.encode(anchor, convert_to_tensor=True)[None, :]
    print("loading embeddings")
    corpus_emb = load_embeddings()

    # Getting hits: one result list per query, and there is a single query
    search_results = util.semantic_search(
        query_emb,
        corpus_emb,
        score_function=util.dot_score,
        top_k=n_answers,
    )
    hits = search_results[0]

    filtered_posts = filter_questions("python")
    print(f"{len(filtered_posts)} posts found with tag: python")

    def _as_row(hit):
        # Resolve one hit into its (title, formatted score, url) triple.
        post = filtered_posts[hit['corpus_id']]
        return (
            post['title'],
            "{:.3f}".format(hit['score']),
            f"https://stackoverflow.com/q/{post['id']}",
        )

    rows = [_as_row(hit) for hit in hits]
    hits_titles = [row[0] for row in rows]
    hits_scores = [row[1] for row in rows]
    urls = [row[2] for row in rows]
    return hits_titles, hits_scores, urls