# synth-net / modeling.py
# Last commit faa78fa by bjorn-hommel: "fix attempt regarding breaking change
# in transformers, bypassing chat template update"
import os
import streamlit as st
import logging
import pandas as pd
import numpy as np
import pickle
import numpy as np
from bertopic import BERTopic
from umap import UMAP
from cryptography.fernet import Fernet
from sentence_transformers import SentenceTransformer, util
from pdb import set_trace as trace
def load_db():
    """Decrypt the bundled embedding database into the Streamlit session.

    Reads ``./surveybot3000.enc`` and decrypts it with the Fernet key held in
    ``st.session_state['encryption_key']``; a ``str`` key is UTF-8 encoded
    first. The decrypted bytes are unpickled into ``st.session_state['db']``,
    and ``st.session_state['is_authenticated']`` is then set to ``True``.
    A Streamlit spinner is displayed for the duration.
    """
    with st.spinner('Loading pre-computed embeddings...'):
        db_file = "./surveybot3000.enc"
        logging.info(f"Loading data from {db_file}!")

        with open(db_file, 'rb') as handle:
            ciphertext = handle.read()

        key = st.session_state['encryption_key']
        key_bytes = key.encode('utf-8') if isinstance(key, str) else key
        plaintext = Fernet(key_bytes).decrypt(ciphertext)

        # SECURITY NOTE: pickle.loads can execute arbitrary code embedded in
        # its input. This is only acceptable because the payload must decrypt
        # under the session's Fernet key; never point this at untrusted data.
        st.session_state['db'] = pickle.loads(plaintext)

        # Only reached when decryption and unpickling both succeeded.
        st.session_state['is_authenticated'] = True
        logging.info(f"Loaded {db_file}!")
def load_model():
    """Load the SurveyBot3000 sentence-transformer into the Streamlit session.

    The model location comes from the ``surveybot3000_remote_path``
    environment variable, falling back to ``surveybot3000_path``. The Hugging
    Face access token comes from ``read_models``. When that variable is unset
    or empty, ``token=True`` is passed instead, which asks the Hugging Face
    libraries to use locally stored credentials. The loaded model is stored in
    ``st.session_state['model']``.

    Raises:
        ValueError: if neither model-path environment variable is set.
    """
    with st.spinner('Loading the model...'):
        model_path = (
            os.environ.get('surveybot3000_remote_path')
            or os.getenv('surveybot3000_path')
        )
        if not model_path:
            # Passing None/"" to SentenceTransformer does not load any
            # pretrained model, so the problem would otherwise only surface
            # later (e.g. at encode() time) with a far less clear error.
            raise ValueError(
                "No model path configured: set the 'surveybot3000_remote_path' "
                "or 'surveybot3000_path' environment variable."
            )
        logging.info(f"Loading model from {model_path}!")

        explicit_token = os.environ.get('read_models')
        # token=True means "use the locally saved Hugging Face token".
        auth_token = explicit_token or True
        # `auth_token` is never falsy, so log based on whether an explicit
        # token was supplied; the old check always printed '***'.
        token_label = (
            '***' if explicit_token
            else 'None (token=True: using locally stored Hugging Face credentials)'
        )
        logging.info(f"Using auth token: {token_label}")

        st.session_state['model'] = SentenceTransformer(
            model_name_or_path=model_path,
            token=auth_token,
            trust_remote_code=True,
            # Slow, legacy tokenizer kept deliberately as a workaround for a
            # breaking change in transformers (see the file's commit note).
            tokenizer_kwargs={
                'use_fast': False,
                'legacy': True,
            },
        )
        logging.info("Loaded model!")
def search():
    """Score every database item stem against the current search query.

    Encodes ``st.session_state['search_query']`` with the session model and
    averages the resulting embeddings into one query vector. That vector is
    compared by cosine similarity with each row of
    ``st.session_state['db']['ItemStemEmbeddings']``. The result is stored in
    ``st.session_state['search_results']`` as a DataFrame with columns
    ``Match``, ``Scale``, ``Instrument`` and ``Reference``, sorted by
    descending ``Match``.
    """
    with st.spinner('Searching the synthetic net...'):
        encoded = st.session_state['model'].encode(
            sentences=st.session_state['search_query'])
        # encode() returns a 1-D vector for a single string and a 2-D
        # (n_sentences, dim) array for a list. atleast_2d makes the mean below
        # produce one query vector in both cases; without it a single string
        # would be averaged down to a scalar.
        query_embedding = np.atleast_2d(encoded).mean(axis=0)

        item_embeddings = np.vstack(
            st.session_state['db']['ItemStemEmbeddings'])

        # cos_sim returns a (1, n_items) tensor. Squeeze only the query axis so
        # a single-item database still yields a 1-D score array; a bare
        # squeeze() would give a 0-d tensor that pandas cannot use as a column.
        query_scores = util.cos_sim(
            a=query_embedding,
            b=item_embeddings,
        ).squeeze(0).cpu().numpy()

        st.session_state['search_results'] = pd.DataFrame({
            'Match': query_scores,
            'Scale': st.session_state['db']['ScaleName'],
            'Instrument': st.session_state['db']['InstrumentName'],
            'Reference': st.session_state['db']['psyctest_doi'],
        }).sort_values(by='Match', ascending=False)