|
import os |
|
import streamlit as st |
|
import logging |
|
import pandas as pd |
|
import numpy as np |
|
import pickle |
|
import numpy as np |
|
from bertopic import BERTopic |
|
from umap import UMAP |
|
from cryptography.fernet import Fernet |
|
from sentence_transformers import SentenceTransformer, util |
|
from pdb import set_trace as trace |
|
|
|
|
|
def load_db():
    """Decrypt the pre-computed embedding database into the Streamlit session.

    Reads ``./surveybot3000.enc``, decrypts it with the Fernet key found under
    ``st.session_state['encryption_key']`` (a ``str`` key is UTF-8 encoded
    first) and stores the unpickled payload in ``st.session_state['db']``.
    ``st.session_state['is_authenticated']`` is set to ``True`` only after the
    payload has been decrypted and unpickled; any exception raised on the way
    propagates to the caller.
    """
    with st.spinner('Loading pre-computed embeddings...'):
        session = st.session_state

        file_path = "./surveybot3000.enc"
        logging.info(f"Loading data from {file_path}!")

        with open(file_path, 'rb') as handle:
            ciphertext = handle.read()

        # Fernet expects the key as bytes; accept a text key as well.
        raw_key = session['encryption_key']
        fernet_key = raw_key.encode('utf-8') if isinstance(raw_key, str) else raw_key

        plaintext = Fernet(fernet_key).decrypt(ciphertext)

        # NOTE(review): unpickling executes arbitrary code from the payload.
        # This is only acceptable as long as the .enc file can exclusively be
        # produced by holders of the encryption key.
        session['db'] = pickle.loads(plaintext)

        session['is_authenticated'] = True
        logging.info(f"Loaded {file_path}!")
|
|
|
|
|
def load_model():
    """Load the sentence-embedding model into ``st.session_state['model']``.

    The model location is taken from the ``surveybot3000_remote_path``
    environment variable, falling back to ``surveybot3000_path`` when the
    former is unset or empty. The access token is taken from the
    ``read_models`` environment variable; when that is unset or empty,
    ``True`` is passed as ``token`` instead.

    Raises:
        ValueError: If neither environment variable provides a model location.
    """
    with st.spinner('Loading the model...'):
        model_path = (
            os.environ.get('surveybot3000_remote_path')
            or os.getenv('surveybot3000_path')
        )
        # Fail early with an actionable message instead of handing ``None`` to
        # SentenceTransformer and failing later with an unrelated error.
        if not model_path:
            raise ValueError(
                "No model location configured: set the "
                "'surveybot3000_remote_path' or 'surveybot3000_path' "
                "environment variable."
            )
        logging.info(f"Loading model from {model_path}!")

        env_token = os.environ.get('read_models')
        # NOTE(review): token=True presumably makes the Hugging Face hub client
        # use locally stored credentials -- confirm against the hub docs.
        auth_token = env_token or True
        # Log from env_token, not auth_token: the latter is always truthy and
        # would make the log claim a token is in use even when none was set.
        logging.info(
            "Using auth token: %s",
            '***' if env_token else 'None (falling back to token=True)',
        )

        # NOTE(review): trust_remote_code=True runs Python code shipped with
        # the model repository; only point model_path at trusted sources.
        st.session_state['model'] = SentenceTransformer(
            model_name_or_path=model_path,
            token=auth_token,
            trust_remote_code=True,
            tokenizer_kwargs={
                'use_fast': False,
                'legacy': True,
            },
        )

        logging.info("Loaded model!")
|
|
|
|
|
def search():
    """Rank database entries by cosine similarity to the current search query.

    Encodes ``st.session_state['search_query']`` with the model stored in
    ``st.session_state['model']``, averages the resulting embedding(s) into a
    single query vector and compares it against the stacked
    ``ItemStemEmbeddings`` of ``st.session_state['db']``. The outcome is written
    to ``st.session_state['search_results']`` as a DataFrame with the columns
    ``Match``, ``Scale``, ``Instrument`` and ``Reference``, sorted by ``Match``
    in descending order.
    """
    with st.spinner('Searching the synthetic net...'):
        encoded_query = st.session_state['model'].encode(
            sentences=st.session_state['search_query'])

        # encode() yields a 1-D vector for a single string and a 2-D array for
        # a list of sentences. Promote to 2-D first so that averaging over
        # axis 0 always produces one embedding vector rather than collapsing a
        # single-string query into a scalar.
        query_embedding = np.atleast_2d(encoded_query).mean(axis=0)

        db = st.session_state['db']
        item_embeddings = np.vstack(db['ItemStemEmbeddings'])

        # One similarity score per row of item_embeddings.
        query_scores = util.cos_sim(
            a=query_embedding,
            b=item_embeddings,
        ).squeeze()

        results = pd.DataFrame({
            'Match': query_scores,
            'Scale': db['ScaleName'],
            'Instrument': db['InstrumentName'],
            'Reference': db['psyctest_doi'],
        })
        st.session_state['search_results'] = results.sort_values(
            by='Match', ascending=False)