import logging
import os
import pickle
from pdb import set_trace as trace  # debugging helper

import numpy as np
import pandas as pd
import streamlit as st
from bertopic import BERTopic
from cryptography.fernet import Fernet
from sentence_transformers import SentenceTransformer, util
from umap import UMAP
def load_db():
    """Decrypt and unpickle the pre-computed item database into session state."""
    with st.spinner('Loading pre-computed embeddings...'):
        file_path = "./surveybot3000.enc"
        logging.info(f"Loading data from {file_path}!")

        # Read the Fernet-encrypted pickle from disk.
        with open(file_path, 'rb') as f:
            encrypted_data = f.read()

        # Fernet expects the key as bytes; accept a string key as well.
        encryption_key = st.session_state['encryption_key']
        if isinstance(encryption_key, str):
            encryption_key = encryption_key.encode('utf-8')

        cipher = Fernet(encryption_key)
        decrypted_df = cipher.decrypt(encrypted_data)
        st.session_state['db'] = pickle.loads(decrypted_df)

        # A successful decryption doubles as authentication.
        st.session_state['is_authenticated'] = True
        logging.info(f"Loaded {file_path}!")
def load_model():
    """Load the SentenceTransformer model, preferring the remote path if configured."""
    with st.spinner('Loading the model...'):
        # Prefer the remote model path; fall back to the local path.
        model_path = os.environ.get('surveybot3000_remote_path') or os.getenv('surveybot3000_path')
        logging.info(f"Loading model from {model_path}!")

        # Use the access token from the environment if present; `True` falls back to the
        # locally cached Hugging Face token.
        auth_token = os.environ.get('read_models') or True
        logging.info(f"Using auth token: {'***' if auth_token else 'None'}")

        st.session_state['model'] = SentenceTransformer(
            model_name_or_path=model_path,
            token=auth_token,
            trust_remote_code=True,
            tokenizer_kwargs={
                'use_fast': False,
                'legacy': True
            }
        )
        logging.info("Loaded model!")
def search():
    """Embed the query items, score them against the database, and store ranked results."""
    with st.spinner('Searching the synthetic net...'):
        # Average the embeddings of all query sentences into a single query vector.
        query_embeddings = st.session_state['model'].encode(
            sentences=st.session_state['search_query']).mean(axis=0)

        # Stack the pre-computed item-stem embeddings into one matrix.
        item_embeddings = np.vstack(
            st.session_state['db']['ItemStemEmbeddings'])

        # Cosine similarity between the query vector and every item stem.
        query_scores = util.cos_sim(
            a=np.array(query_embeddings),
            b=item_embeddings
        ).squeeze()

        # Rank items by similarity, keeping scale, instrument, and reference metadata.
        st.session_state['search_results'] = pd.DataFrame({
            'Match': query_scores,
            'Scale': st.session_state['db']['ScaleName'],
            'Instrument': st.session_state['db']['InstrumentName'],
            'Reference': st.session_state['db']['psyctest_doi'],
        }).sort_values(by='Match', ascending=False)
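

# Minimal wiring sketch (an assumption for illustration; the actual page layout that calls
# these helpers is not shown in this excerpt). Kept commented out so it does not run:
#
#   key = st.text_input('Decryption key', type='password')
#   if st.button('Unlock'):
#       st.session_state['encryption_key'] = key
#       load_db()
#       load_model()
#
#   st.session_state['search_query'] = st.text_area('Describe your construct').splitlines()
#   if st.button('Search') and st.session_state.get('is_authenticated'):
#       search()
#       st.dataframe(st.session_state['search_results'])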