import logging
import os
import pickle

import numpy as np
import pandas as pd
import streamlit as st
from cryptography.fernet import Fernet
from sentence_transformers import SentenceTransformer, util


def load_model():
    """Load the selected SentenceTransformer model and its encrypted embedding database."""
    no_model = st.session_state.get('model') is None
    swap_model = st.session_state.get('input_model_name') != st.session_state.get('loaded_model_name')
    if swap_model or no_model:
        with st.spinner('Loading the model might take a couple of seconds...'):
            # Resolve the model location from environment variables; the remote
            # path (e.g. a Hub repository id) takes precedence over the local one.
            env_local = st.session_state['input_model_name'].lower() + '_path'
            env_remote = st.session_state['input_model_name'].lower() + '_remote_path'
            model_path = os.environ.get(env_remote) or os.getenv(env_local)
            # Use the token from the 'read_models' secret if set, otherwise fall
            # back to the locally cached Hugging Face token (use_auth_token=True).
            auth_token = os.environ.get('read_models') or True
            st.session_state['model'] = SentenceTransformer(
                model_name_or_path=model_path,
                use_auth_token=auth_token
            )
            st.session_state['loaded_model_name'] = st.session_state['input_model_name']
            logging.info(f"Loaded {st.session_state['input_model_name']}!")
        with st.spinner('Loading embeddings...'):
            # The embedding database ships as a Fernet-encrypted pickle.
            file_path = f"./{st.session_state['input_model_name'].lower()}.enc"
            with open(file_path, 'rb') as f:
                encrypted_data = f.read()
            try:
                cipher = Fernet(st.session_state['decrypt_key'])
                decrypted_db = cipher.decrypt(encrypted_data)
                st.session_state['db'] = pickle.loads(decrypted_db)
                st.session_state['valid_decrypt_key'] = True
            except Exception as e:
                st.error(body="Error: No valid encryption key!", icon="🔑")
                logging.error(e)
                return
            # Unencrypted alternative:
            # st.session_state['db'] = pd.read_parquet(
            #     path=f"./{st.session_state['input_model_name'].lower()}.parquet"
            # )
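

# A minimal sketch (not part of the original app) of how the encrypted
# embedding file read in load_model() could have been produced: the database
# is assumed to be a pandas DataFrame with the columns used in search() below,
# pickled and then encrypted with a Fernet key. The helper name and the
# db/model_name/key parameters are illustrative assumptions.
def encrypt_embeddings(db: pd.DataFrame, model_name: str, key: bytes) -> None:
    # Hypothetical helper, not referenced by the app: serialize the DataFrame
    # and encrypt it with the same key users later paste into the UI, e.g.
    # key = Fernet.generate_key(), supplied at runtime as
    # st.session_state['decrypt_key'].
    cipher = Fernet(key)
    with open(f"./{model_name.lower()}.enc", 'wb') as f:
        f.write(cipher.encrypt(pickle.dumps(db)))

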
def search():
    """Embed the query items and rank every database item by cosine similarity."""
    with st.spinner('Searching...'):
        # Encode the query (expected to be a list of item texts) and average
        # the sentence embeddings into a single query vector.
        query_embeddings = st.session_state['model'].encode(
            sentences=st.session_state['search_query']
        ).mean(axis=0)
        query_scores = util.cos_sim(
            a=np.array(query_embeddings),
            b=st.session_state['db']['ItemStemEmbeddings']
        ).squeeze()
        st.session_state['results'] = pd.DataFrame({
            'Match': query_scores,
            'Scale': st.session_state['db']['ScaleName'],
            'Instrument': st.session_state['db']['InstrumentName'],
            'Reference': st.session_state['db']['InstrumentApaReference'],
        }).sort_values(by='Match', ascending=False)
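

# Hypothetical wiring sketch (an assumption; the actual page layout may be
# defined elsewhere in this file, so it is left commented out): shows how the
# two callbacks above could be driven from Streamlit widgets. Widget labels
# and the placeholder model option are illustrative, not taken from the app.
#
# st.text_input('Decryption key', key='decrypt_key', type='password')
# st.selectbox('Model', options=['<model_name>'], key='input_model_name')
# query = st.text_area('Items (one per line)')
# st.session_state['search_query'] = [q for q in query.splitlines() if q.strip()]
# if st.button('Search'):
#     load_model()
#     if st.session_state.get('valid_decrypt_key'):
#         search()
#         st.dataframe(st.session_state['results'])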