from sentence_transformers import SentenceTransformer
import litserve as ls
from huggingface_hub import login
import os

# Authenticate with the Hugging Face Hub so gated/private models can load.
# NOTE(review): if HF_TOKEN is unset this passes token=None and login() falls
# back to its default behavior — confirm that is intended in deployment.
login(token=os.getenv("HF_TOKEN"))

# Deployment configuration, injected via environment variables.
# DATA_PATH: local directory containing the model folders.
# *_MODEL_NAME: subdirectory names of the two ONNX-exported models.
# NOTE(review): any of these may be None if the env var is missing, which
# would make os.path.join in setup() raise — consider failing fast here.
DATA_PATH = os.getenv("DATA_PATH")
RETRIEVAL_MODEL_NAME = os.getenv("RETRIEVAL_MODEL_NAME")
SIMILARITY_MODEL_NAME = os.getenv("SIMILARITY_MODEL_NAME")


class EmbeddingModelAPI(ls.LitAPI):
    """LitServe API serving two sentence-embedding models.

    Requests are JSON objects of the form
    ``{"sentences": [...], "type": "default" | "similarity"}``.
    ``"default"`` routes to the retrieval model, ``"similarity"`` to the
    similarity model; the response is ``{"data": [[...], ...]}`` with one
    embedding vector per input sentence.
    """

    @staticmethod
    def _load_model(model_name):
        """Load an ONNX-backed SentenceTransformer from DATA_PATH/model_name."""
        return SentenceTransformer(
            os.path.join(DATA_PATH, model_name),
            backend="onnx",
            model_kwargs={"file_name": "onnx/model.onnx"},
            trust_remote_code=True,
        )

    def setup(self, device):
        """Load both embedding models once at server startup.

        Args:
            device: Target device chosen by LitServe (unused here; the ONNX
                backend selects its own execution provider).
        """
        self.retrieval_model = self._load_model(RETRIEVAL_MODEL_NAME)
        # Retrieval inputs are long document chunks; raise the truncation limit.
        self.retrieval_model.max_seq_length = 2048
        self.similarity_model = self._load_model(SIMILARITY_MODEL_NAME)

    def decode_request(self, request, **kwargs):
        """Extract (sentences, embedding_type) from the request JSON.

        Raises:
            KeyError: if "sentences" or "type" is missing from the payload.
        """
        sentences = request["sentences"]
        # Renamed from `type` to avoid shadowing the builtin.
        embedding_type = request["type"]
        return sentences, embedding_type

    def predict(self, x, **kwargs):
        """Embed the decoded sentences with the model selected by type.

        Args:
            x: Tuple ``(chunks, embedding_type)`` produced by decode_request.

        Returns:
            A list of embedding vectors (list of lists of floats).

        Raises:
            ValueError: if the request type is not a known model selector.
                (Previously an unknown type silently returned None, which
                surfaced as an opaque failure downstream.)
        """
        chunks, embedding_type = x
        if embedding_type == "default":
            return self.retrieval_model.encode(chunks).tolist()
        if embedding_type == "similarity":
            return self.similarity_model.encode(chunks).tolist()
        raise ValueError(
            f"Unknown embedding type {embedding_type!r}; "
            "expected 'default' or 'similarity'."
        )

    def encode_response(self, output, **kwargs):
        """Wrap the embeddings in the response envelope expected by clients."""
        return {"data": output}


if __name__ == "__main__":
    api = EmbeddingModelAPI()
    server = ls.LitServer(api)
    server.run(generate_client_file=False, port=7860)