import logging
from typing import Any
from typing import Dict
from typing import Optional
from typing import Union

from djl_python.deepspeed import DeepSpeedService
from djl_python.inputs import Input
from djl_python.outputs import Output
from hp_validation import _update_num_beams
from hp_validation import _validate_payload
from input import process_input
from transformers import set_seed


class DeepSpeedServiceTextGeneration(DeepSpeedService):
    """DeepSpeed service with a text-generation specific ``inference`` method.

    Only ``inference`` is overridden; everything else comes from the base
    class. For the definition of DeepSpeedService, see
    https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/setup/djl_python/deepspeed.py
    """

    def inference(self, inputs: Input) -> Output:
        """Run the pipeline on a request and wrap the result in an ``Output``.

        The request is first converted by ``process_input`` into the pipeline
        input plus keyword arguments (presumably this is where the text
        generation hyperparameters are validated -- confirm in ``input.py``).

        Args:
            inputs (djl_python.inputs.Input): input containing payload and
                content type.

        Returns:
            djl_python.outputs.Output: the pipeline result on success; if any
            step raises, the exception is logged and an error output built
            from ``str(exception)`` is returned instead.
        """
        try:
            payload, generation_kwargs = process_input(
                inputs=inputs,
                text_input_for_bloom_model=False,
            )
            prediction = self.pipeline(payload, **generation_kwargs)
            response = Output()
            response.add(prediction)
            return response
        except Exception as err:
            # Service boundary: report the failure through the response object
            # rather than letting the exception propagate to the caller.
            logging.exception("Failed to do inference")
            return Output().error(str(err))


# Single module-level service instance shared by every call to ``handle``.
_service = DeepSpeedServiceTextGeneration()


def handle(inputs: Input) -> Optional[Output]:
    """Module entry point: initialize the service if needed, then run inference.

    Args:
        inputs (djl_python.inputs.Input): input containing payload and
            content type; its properties are passed to ``initialize`` when
            the service has not been initialized yet.

    Returns:
        Optional[djl_python.outputs.Output]: ``None`` when ``inputs`` is
        empty, otherwise the output of ``_service.inference(inputs)``.
    """
    if not _service.initialized:
        _service.initialize(inputs.get_properties())

    # Initialization happens even for an empty request; only inference is skipped.
    return None if inputs.is_empty() else _service.inference(inputs)