# NOTE: the original paste carried a Hugging Face Spaces page banner here
# ("Spaces: / Running / Running") — page residue, not code; kept as a comment
# so the module parses.
import os

import litserve as ls
from huggingface_hub import login
from sentence_transformers import SentenceTransformer

# Authenticate with the Hugging Face Hub so gated/private model files can be
# downloaded. HF_TOKEN is expected to be set in the environment (e.g. a
# Space secret); login(token=None) would fall back to an interactive prompt.
login(token=os.getenv("HF_TOKEN"))

# Filesystem layout of the two ONNX model directories, injected via the
# environment. NOTE(review): os.getenv returns None when unset, which would
# make os.path.join fail later in setup() — presumably these are guaranteed
# by the deployment config; verify.
DATA_PATH = os.getenv("DATA_PATH")
RETRIEVAL_MODEL_NAME = os.getenv("RETRIEVAL_MODEL_NAME")
SIMILARITY_MODEL_NAME = os.getenv("SIMILARITY_MODEL_NAME")
class EmbeddingModelAPI(ls.LitAPI):
    """LitServe API serving two ONNX-backed SentenceTransformer models.

    A per-request ``type`` field selects between the retrieval model
    (``"default"``) and the similarity model (``"similarity"``).
    """

    def setup(self, device):
        """Load both models from DATA_PATH using the ONNX runtime backend.

        Called once per worker by LitServe; ``device`` is supplied by the
        framework (unused here — the ONNX backend manages placement).
        """
        self.retrieval_model = SentenceTransformer(
            os.path.join(DATA_PATH, RETRIEVAL_MODEL_NAME),
            backend="onnx",
            model_kwargs={"file_name": "onnx/model.onnx"},
            trust_remote_code=True,
        )
        # Longer context window for retrieval inputs (document chunks);
        # the similarity model keeps its default max_seq_length.
        self.retrieval_model.max_seq_length = 2048
        self.similarity_model = SentenceTransformer(
            os.path.join(DATA_PATH, SIMILARITY_MODEL_NAME),
            backend="onnx",
            model_kwargs={"file_name": "onnx/model.onnx"},
            trust_remote_code=True,
        )

    def decode_request(self, request, **kwargs):
        """Extract ``(sentences, type)`` from the JSON request payload.

        Raises:
            KeyError: if either required field is missing.
        """
        # Avoid binding to the name `type` (shadows the builtin).
        return request["sentences"], request["type"]

    def predict(self, x, **kwargs):
        """Encode the sentences with the model selected by the request type.

        Returns:
            list: embeddings as nested Python lists (JSON-serializable).

        Raises:
            ValueError: if the request type is not recognized.
        """
        chunks, request_type = x
        if request_type == "default":
            return self.retrieval_model.encode(chunks).tolist()
        if request_type == "similarity":
            return self.similarity_model.encode(chunks).tolist()
        # Previously an unrecognized type fell through and returned None,
        # yielding a confusing {"data": null} response — fail loudly instead.
        raise ValueError(f"unknown embedding type: {request_type!r}")

    def encode_response(self, output, **kwargs):
        """Wrap the embedding list in the response envelope."""
        return {"data": output}
if __name__ == "__main__":
    # Port 7860 is the port Hugging Face Spaces expects the app to listen on.
    # generate_client_file=False skips writing the auto-generated client.py.
    api = EmbeddingModelAPI()
    server = ls.LitServer(api)
    server.run(generate_client_file=False, port=7860)