import logging
from typing import Any
from typing import Dict
from typing import Optional
from typing import Union

from djl_python.deepspeed import DeepSpeedService
from djl_python.inputs import Input
from djl_python.outputs import Output
from hp_validation import _update_num_beams
from hp_validation import _validate_payload
from input import process_input
from transformers import set_seed


class DeepSpeedServiceTextGeneration(DeepSpeedService):
    """Define subclass DeepSpeedServiceTextGeneration.

    This subclass customizes the inference method to add hyperparameter validation
    for the text generation task. For the definition of DeepSpeedService, see
    https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/setup/djl_python/deepspeed.py
    """

    def inference(self, inputs: Input) -> Output:
        """Define customized inference method with hyperparameter validation for the text generation task.

        Args:
            inputs (djl_python.inputs.Input): input containing payload and content type.
        Returns:
            outputs (djl_python.outputs.Output): model prediction output.
        """
        try:
            # Parse the request payload and its generation hyperparameters.
            input_data, model_kwargs = process_input(inputs=inputs, text_input_for_bloom_model=False)
            # Run the text generation pipeline set up when the service was initialized.
            result = self.pipeline(input_data, **model_kwargs)
            outputs = Output()
            outputs.add(result)
        except Exception as e:
            logging.exception("Failed to do inference")
            outputs = Output().error(str(e))
        return outputs


_service = DeepSpeedServiceTextGeneration()


def handle(inputs: Input) -> Optional[Output]:
    """Define customized handler entry point for the text generation task.

    Args:
        inputs (djl_python.inputs.Input): input containing payload and content type.
    Returns:
        outputs (djl_python.outputs.Output): model prediction output, or None for an
            empty (warm-up) request.
    """
    # Lazily initialize the service with the model properties on the first request.
    if not _service.initialized:
        _service.initialize(inputs.get_properties())

    # An empty request is typically sent by the model server at startup to trigger
    # initialization; there is nothing to predict in that case.
    if inputs.is_empty():
        return None

    return _service.inference(inputs)
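# ---------------------------------------------------------------------------
# Illustrative sketch (comments only, not executed): the shape of a request
# this handler is expected to process. The field names below are assumptions
# for illustration; the accepted schema and hyperparameter validation live in
# process_input and hp_validation, not here.
#
#   e.g. a JSON body such as:
#       {
#           "text_inputs": "Once upon a time",   # prompt to continue (hypothetical key)
#           "max_length": 50,                    # generation hyperparameter (hypothetical key)
#           "num_beams": 2                       # generation hyperparameter (hypothetical key)
#       }
#
# The model server deserializes the request into a djl_python Input, invokes
# handle(), and serializes the returned Output back to the client.
# ---------------------------------------------------------------------------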