# embedding/app/server.py
# Author: nam pham
# feat: fix max sequence length (commit 36e09ce)
from sentence_transformers import SentenceTransformer
import litserve as ls
from huggingface_hub import login
import os
login(token=os.getenv("HF_TOKEN"))
DATA_PATH = os.getenv("DATA_PATH")
RETRIEVAL_MODEL_NAME = os.getenv("RETRIEVAL_MODEL_NAME")
SIMILARITY_MODEL_NAME = os.getenv("SIMILARITY_MODEL_NAME")
class EmbeddingModelAPI(ls.LitAPI):
def setup(self, device):
self.retrieval_model = SentenceTransformer(
os.path.join(DATA_PATH, RETRIEVAL_MODEL_NAME),
backend="onnx",
model_kwargs={"file_name": "onnx/model.onnx"},
trust_remote_code=True,
)
self.retrieval_model.max_seq_length = 2048
self.similarity_model = SentenceTransformer(
os.path.join(DATA_PATH, SIMILARITY_MODEL_NAME),
backend="onnx",
model_kwargs={"file_name": "onnx/model.onnx"},
trust_remote_code=True,
)
def decode_request(self, request, **kwargs):
sentences = request["sentences"]
type = request["type"]
return sentences, type
def predict(self, x, **kwargs):
chunks, type = x
if type == "default":
return self.retrieval_model.encode(chunks).tolist()
elif type == "similarity":
return self.similarity_model.encode(chunks).tolist()
def encode_response(self, output, **kwargs):
return {"data": output}
if __name__ == "__main__":
api = EmbeddingModelAPI()
server = ls.LitServer(api)
server.run(generate_client_file=False, port=7860)