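"""Gradio demo: semantic search over the BEIR fiqa corpus with AutoGluon.

A user query and a small random sample of fiqa documents are embedded with
sentence-transformers/all-MiniLM-L6-v2 through MultiModalPredictor, and the
documents are then ranked by cosine similarity to the query.
"""
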
import gradio as gr
import ir_datasets
import pandas as pd
import numpy as np

from autogluon.multimodal import MultiModalPredictor


# Populated by the Gradio handlers below.
query_embedding = None     # embedding of the single user query
document_embedding = None  # embeddings of the sampled fiqa documents
docs_df = None             # the sampled fiqa documents themselves

def text_embedding_batch():
    """Embed a small random sample of the BEIR fiqa dev corpus."""
    global document_embedding
    global docs_df
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    dataset = ir_datasets.load("beir/fiqa/dev")
    # Sample a tiny fraction of the corpus so the demo stays responsive.
    docs_df = pd.DataFrame(dataset.docs_iter()).set_index("doc_id").sample(frac=0.0001)
    predictor = MultiModalPredictor(
        pipeline="feature_extraction",
        hyperparameters={
            "model.hf_text.checkpoint_name": model_name
        }
    )
    embedding = predictor.extract_embedding(docs_df)
    # These are document embeddings, keyed by the "text" column of docs_df.
    document_embedding = embedding["text"]
    return document_embedding


def text_embedding_single(query: str):
    """Embed a single free-text query."""
    global query_embedding
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    predictor = MultiModalPredictor(
        pipeline="feature_extraction",
        hyperparameters={
            "model.hf_text.checkpoint_name": model_name
        }
    )
    embedding = predictor.extract_embedding([query])
    # List inputs are keyed by column position, hence "0".
    query_embedding = embedding["0"]
    return query_embedding
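
# Possible refactor (a sketch, not used by the handlers above): both handlers
# rebuild the predictor, so the checkpoint is reloaded on every button click.
# A single module-level predictor could be constructed once and shared; the
# name SHARED_PREDICTOR below is hypothetical.
#
#     SHARED_PREDICTOR = MultiModalPredictor(
#         pipeline="feature_extraction",
#         hyperparameters={
#             "model.hf_text.checkpoint_name": "sentence-transformers/all-MiniLM-L6-v2"
#         },
#     )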


def rank_document():
    """Rank the sampled documents by cosine similarity to the query."""
    global query_embedding
    global document_embedding
    global docs_df
    if query_embedding is None or document_embedding is None:
        # Both embeddings must be generated before ranking.
        return pd.DataFrame()
    # L2-normalize both sides so the dot product equals cosine similarity.
    q_norm = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
    d_norm = document_embedding / np.linalg.norm(document_embedding, axis=-1, keepdims=True)
    scores = d_norm.dot(q_norm[0])  # one cosine score per document
    # Return the top 3 documents, highest score first.
    top_idx = np.argsort(-scores)[:3]
    return docs_df["text"].iloc[top_idx].to_frame()
    

def main():
    with gr.Blocks(title="OpenSearch Demo") as demo:
        gr.Markdown("# Semantic Search with Autogluon")
        gr.Markdown("Ask an open question!")
        with gr.Row():
            inp_single = gr.Textbox(show_label=False)
        with gr.Row():    
            btn_single = gr.Button("Generate Embedding")
        with gr.Row():
            out_single = gr.DataFrame(label="Embedding", show_label=True)
        gr.Markdown("You can select one of the sample datasets for document embedding")
        with gr.Row():
            btn_fiqa = gr.Button("fiqa")
        with gr.Row():
            out_batch = gr.DataFrame(label="Sample Embeddings", show_label=True, row_count=5)
        gr.Markdown("Now rank the documents and pick the top 3 most relevant from the dataset")
        with gr.Row():    
            btn_rank = gr.Button("Rank documents")
        with gr.Row():
            out_rank = gr.DataFrame(label="Top ranked documents", show_label=True, row_count=5)
        
        btn_single.click(fn=text_embedding_single, inputs=inp_single, outputs=out_single)
        btn_fiqa.click(fn=text_embedding_batch, inputs=None, outputs=out_batch)
        btn_rank.click(fn=rank_document, inputs=None, outputs=out_rank)
    demo.launch()   


if __name__ == "__main__":
    main()
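
# To try it out (package names below are the assumed PyPI names; the model
# checkpoint and dataset are downloaded on first use, so network access is
# required):
#
#     pip install gradio ir_datasets pandas numpy autogluon.multimodal
#     python app.py   # or whatever this file is named
#
# Then open the local URL Gradio prints, click "Generate Embedding" for a
# query, "fiqa" to embed the document sample, and "Rank documents".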