# Source: redPajama-3b-zAgile-base / inference.py
# Author: dtorres-zAgile — commit 0be3778 ("Upload 25 files")
import logging
from typing import Any
from typing import Dict
from typing import Optional
from typing import Union
from djl_python.deepspeed import DeepSpeedService
from djl_python.inputs import Input
from djl_python.outputs import Output
from hp_validation import _update_num_beams
from hp_validation import _validate_payload
from input import process_input
from transformers import set_seed
class DeepSpeedServiceTextGeneration(DeepSpeedService):
    """DeepSpeed-backed service with a customized text-generation ``inference`` step.

    Only ``inference`` is overridden here; ``initialize`` and the ``pipeline``
    attribute used below come from DJL's ``DeepSpeedService``. For its
    definition, see
    https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/setup/djl_python/deepspeed.py
    """

    def inference(self, inputs: Input) -> Output:
        """Run one text-generation request through the inherited pipeline.

        The raw request is converted into pipeline input plus keyword arguments
        by ``input.process_input``. That helper presumably also performs the
        hyperparameter validation for the text-generation task; confirm in
        ``input.py``. Any exception raised while parsing, generating or
        packaging the result is logged with its traceback. It is then returned
        as an error ``Output`` instead of being raised to the caller.

        Args:
            inputs (djl_python.inputs.Input): input containing payload and content type.

        Returns:
            djl_python.outputs.Output: the pipeline result on success, or an
            error output carrying ``str(exception)`` on failure.
        """
        try:
            # NOTE(review): the flag's meaning is defined in input.process_input;
            # presumably False because this model is not BLOOM — confirm.
            pipeline_input, generation_kwargs = process_input(
                inputs=inputs, text_input_for_bloom_model=False
            )
            generated = self.pipeline(pipeline_input, **generation_kwargs)
            response = Output()
            response.add(generated)
        except Exception as err:
            # Request boundary: report the failure to the caller rather than raising.
            logging.exception("Failed to do inference")
            return Output().error(str(err))
        return response
# Shared service instance created at import time; its initialize() is deferred
# to the first handle() call so it can receive the request's properties.
_service = DeepSpeedServiceTextGeneration()


def handle(inputs: Input) -> Optional[Output]:
    """Module-level request handler, presumably the function DJL Serving's Python engine invokes.

    On the first call, the shared service is initialized from the request's
    properties. After that, non-empty requests are forwarded to
    ``DeepSpeedServiceTextGeneration.inference``.

    Args:
        inputs (djl_python.inputs.Input): input containing payload and content type.

    Returns:
        Optional[djl_python.outputs.Output]: ``None`` for an empty request
        (presumably the server's startup/warm-up call — confirm), otherwise
        the model prediction output.
    """
    if not _service.initialized:
        _service.initialize(inputs.get_properties())
    # Initialization above still happens for an empty request; only inference is skipped.
    return None if inputs.is_empty() else _service.inference(inputs)