Trent
List model loading support
31f3439
raw
history blame
484 Bytes
import streamlit as st
from sentence_transformers import SentenceTransformer
from .config import MODELS_ID
# ``st.cache(allow_output_mutation=True)`` is deprecated in Streamlit; the
# documented replacement for caching unhashable global resources such as ML
# models is ``st.cache_resource``. Fall back to the legacy decorator on older
# Streamlit releases that do not provide it.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def load_model(model_name):
    """Load the SentenceTransformer model(s) registered under ``model_name``.

    The result is cached by Streamlit, so each model name is only constructed
    once per process (models are instantiated lazily, on first request, rather
    than at import time).

    Args:
        model_name: A key of ``MODELS_ID``.

    Returns:
        A single ``SentenceTransformer`` when ``MODELS_ID[model_name]`` is a
        string, or a list of ``SentenceTransformer`` objects (in iteration
        order) when it is any other iterable of model identifiers.

    Raises:
        ValueError: If ``model_name`` is not a key of ``MODELS_ID``.
        TypeError: If the configured value is neither a string nor iterable.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``-O``.
    if model_name not in MODELS_ID:
        raise ValueError(
            f"Unknown model name {model_name!r}; "
            f"expected one of {sorted(MODELS_ID)!r}"
        )

    model_ids = MODELS_ID[model_name]

    # ``str`` is itself iterable, so it must be checked before the generic
    # iterable branch or a single id would be split into characters.
    if isinstance(model_ids, str):
        return SentenceTransformer(model_ids)
    if hasattr(model_ids, "__iter__"):
        return [SentenceTransformer(name) for name in model_ids]

    # Previously this fell through with ``output`` unbound (UnboundLocalError).
    raise TypeError(
        f"MODELS_ID[{model_name!r}] must be a model id string or an iterable "
        f"of model id strings, got {type(model_ids).__name__}"
    )