File size: 2,779 Bytes
818f654
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import streamlit as st
import logging
import pandas as pd
import numpy as np
import pickle
from cryptography.fernet import Fernet
from sentence_transformers import SentenceTransformer, util

def load_model():
    """Ensure the selected model and its encrypted embeddings are loaded.

    Reads ``st.session_state['input_model_name']``. When no model is loaded
    yet, or a different model was selected, a ``SentenceTransformer`` is
    created from the ``<name>_remote_path`` environment variable (falling back
    to ``<name>_path``). Embeddings are (re)loaded after a model change and
    also whenever no embeddings are present or the last decryption attempt
    failed, so that a corrected key takes effect on the next rerun.

    Session-state keys written: ``model``, ``loaded_model_name``, ``db``,
    ``valid_decrypt_key``. A missing model path or an unusable decryption key
    is reported with ``st.error`` and logged; the function then returns early.
    """
    state = st.session_state

    # `.get` on every key: on a fresh session none of them has been set yet.
    model_changed = (
        state.get('model') is None
        or state.get('input_model_name') != state.get('loaded_model_name')
    )
    # Without this, a failed decryption would only be retried after switching
    # to a different model.
    embeddings_missing = state.get('db') is None or not state.get('valid_decrypt_key')

    if not (model_changed or embeddings_missing):
        return

    model_name = state['input_model_name']

    if model_changed:
        with st.spinner('Loading the model might take a couple of seconds...'):
            model_path = _resolve_model_path(model_name)
            if not model_path:
                # Passing None would silently build an empty, unusable model.
                st.error(body=f"Error: No model path configured for {model_name}!", icon="⚠️")
                logging.error("Neither %s_remote_path nor %s_path is set",
                              model_name.lower(), model_name.lower())
                return

            # Falls back to True when `read_models` is unset; presumably this
            # makes the Hub client use a locally stored token — confirm.
            # NOTE(review): newer sentence-transformers versions deprecate
            # `use_auth_token` in favour of `token`; check the pinned version.
            auth_token = os.environ.get('read_models') or True

            state['model'] = SentenceTransformer(
                model_name_or_path=model_path,
                use_auth_token=auth_token,
            )
            state['loaded_model_name'] = model_name
            logging.info("Loaded %s!", model_name)

    with st.spinner('Loading embeddings...'):
        _load_embeddings(model_name)


def _resolve_model_path(model_name):
    """Return ``<name>_remote_path`` if set and non-empty, else ``<name>_path`` (may be None)."""
    prefix = model_name.lower()
    return os.environ.get(f"{prefix}_remote_path") or os.environ.get(f"{prefix}_path")


def _load_embeddings(model_name):
    """Decrypt ``./<model_name>.enc`` with the session key into ``st.session_state['db']``.

    Previously loaded embeddings are discarded first, so ``db`` is either absent
    or belongs to ``model_name``. A missing or wrong key is shown in the UI,
    sets ``valid_decrypt_key`` to False, and leaves ``db`` unset. A missing
    ``.enc`` file still raises (``OSError``).
    """
    from cryptography.fernet import InvalidToken

    state = st.session_state
    state.pop('db', None)

    file_path = f"./{model_name.lower()}.enc"
    with open(file_path, 'rb') as f:
        encrypted_data = f.read()

    try:
        cipher = Fernet(state['decrypt_key'])
        decrypted_data = cipher.decrypt(encrypted_data)
    except (KeyError, TypeError, ValueError, InvalidToken) as e:
        # KeyError: no key in the session; TypeError/ValueError: malformed key;
        # InvalidToken: well-formed key that does not match this file.
        state['valid_decrypt_key'] = False
        st.error(body="Error: No valid encryption key!", icon="🔑")
        logging.error("Could not decrypt %s: %r", file_path, e)
        return

    # SECURITY: pickle.loads can execute arbitrary code. This is tolerable only
    # because Fernet authenticates the ciphertext, i.e. the payload was produced
    # by someone holding the key — never feed it unauthenticated data.
    state['db'] = pickle.loads(decrypted_data)
    state['valid_decrypt_key'] = True

def search():
    """Rank the entries of ``st.session_state['db']`` by similarity to the query.

    Encodes ``st.session_state['search_query']`` with the loaded model and
    averages the embeddings into a single query vector. That vector is scored
    against ``db['ItemStemEmbeddings']`` with cosine similarity. The result is
    stored in ``st.session_state['results']`` as a DataFrame with columns
    ``Match``, ``Scale``, ``Instrument`` and ``Reference``, sorted by ``Match``
    in descending order.
    """
    state = st.session_state

    with st.spinner('Searching...'):
        embeddings = state['model'].encode(sentences=state['search_query'])
        # `encode` returns a 1-D vector for a single string and a 2-D array for
        # a list of strings. Promote to 2-D so that averaging over axis 0 always
        # combines whole embeddings instead of collapsing one vector to a scalar.
        query_embedding = np.atleast_2d(embeddings).mean(axis=0)

        db = state['db']
        # cos_sim yields a (1, n_items) matrix here. Taking row 0 (rather than
        # squeeze()) keeps a 1-D score vector even when db has a single item.
        query_scores = util.cos_sim(a=query_embedding, b=db['ItemStemEmbeddings'])[0]

        state['results'] = pd.DataFrame({
            'Match': query_scores,
            'Scale': db['ScaleName'],
            'Instrument': db['InstrumentName'],
            'Reference': db['InstrumentApaReference'],
        }).sort_values(by='Match', ascending=False)