import jax.numpy as jnp
import pandas as pd

from sentence_transformers import util
from typing import List

from backend.utils import load_model, filter_questions, load_embeddings


def cos_sim(a, b):
    # Row-wise cosine similarity between the rows of `a` and the rows of `b`.
    # Each row must be normalised separately; a single global norm is only
    # correct when the matrix has exactly one row.
    a = a / jnp.linalg.norm(a, axis=-1, keepdims=True)
    b = b / jnp.linalg.norm(b, axis=-1, keepdims=True)
    return jnp.matmul(a, jnp.transpose(b))
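
# A quick illustrative check (not used by the app): with row-wise
# normalisation, identical rows score 1.0 and orthogonal rows 0.0, e.g.
# cos_sim(jnp.eye(2), jnp.eye(2)) -> [[1., 0.], [0., 1.]].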


# Compute cosine similarity between an anchor string and a list of inputs.
def text_similarity(anchor: str, inputs: List[str], model_name: str, model_dict: dict):
    print(model_name)
    model = load_model(model_name, model_dict)

    # Create the embeddings
    if hasattr(model, 'encode'):
        anchor_emb = model.encode(anchor)[None, :]
        inputs_emb = model.encode(inputs)
    else:
        # A two-part model: the first encoder embeds the anchor,
        # the second embeds the inputs.
        assert len(model) == 2
        anchor_emb = model[0].encode(anchor)[None, :]
        inputs_emb = model[1].encode(inputs)

    # Compute similarity; ravel (rather than squeeze) keeps a 1-D array
    # even when `inputs` holds a single string.
    similarity = jnp.ravel(cos_sim(anchor_emb, inputs_emb))

    # Return the scores as a pandas DataFrame
    d = {'inputs': inputs,
         'score': [round(float(s), 3) for s in similarity]}
    df = pd.DataFrame(d, columns=['inputs', 'score'])

    return df


# Semantic search over the pre-computed corpus embeddings
def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
    # Only the QA retrieval model is supported for search
    print(model_name)
    assert model_name == "mpnet_qa"
    model = load_model(model_name, model_dict)

    # Embed the query
    query_emb = model.encode(anchor, convert_to_tensor=True)[None, :]

    # Load the pre-computed corpus embeddings
    print("loading embeddings")
    corpus_emb = load_embeddings()

    # Retrieve the top `n_answers` matches by dot-product score
    hits = util.semantic_search(query_emb, corpus_emb, score_function=util.dot_score, top_k=n_answers)[0]
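    # `hits` is the first (and only) per-query result list: one dict per
    # match with 'corpus_id' (an index into the embedded corpus) and a
    # float 'score'.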

    # The corpus embeddings are aligned with the python-tagged posts,
    # so each 'corpus_id' indexes into this list.
    filtered_posts = filter_questions("python")
    print(f"{len(filtered_posts)} posts found with tag: python")

    hits_titles = []
    hits_scores = []
    urls = []
    for hit in hits:
        post = filtered_posts[hit['corpus_id']]
        hits_titles.append(post['title'])
        hits_scores.append("{:.3f}".format(hit['score']))
        urls.append(f"https://stackoverflow.com/q/{post['id']}")

    return hits_titles, hits_scores, urls
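

if __name__ == "__main__":
    # Minimal usage sketch. The real `model_dict` is assembled elsewhere in
    # the app; the empty dict below is a placeholder assumption about how
    # backend.utils.load_model resolves models.
    example_model_dict = {}  # hypothetical contents
    df = text_similarity(
        anchor="How do I sort a list in Python?",
        inputs=["Sorting a list of tuples", "Reading a CSV with pandas"],
        model_name="mpnet_qa",
        model_dict=example_model_dict,
    )
    print(df)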