import gc
import logging
import os
import traceback
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

import deepspeed
import torch
from constants import constants
from djl_python.inputs import Input
from djl_python.outputs import Output
from hp_validation import _update_num_beams
from hp_validation import _validate_payload
from input import process_input
from transformers import AutoConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import PreTrainedModel
from transformers import PreTrainedTokenizer

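# Map the serving "dtype" property string to the torch dtype passed to DeepSpeed inference.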
TORCH_DTYPE_FROM_STR_MAPPING = {constants.INT8: torch.int8, constants.FP16: torch.float16}


class Service(object):
    """Service class for text generation.

    This class provides customized model loading and inference methods.
    """

    def __init__(self):
        """Initialize model and tokenizer."""
        self.model: PreTrainedModel = None
        self.tokenizer: PreTrainedTokenizer = None

    def load_model(self, properties: dict) -> None:
        """Load the Bloom or BloomZ 176B model and tokenizer from disk into memory as instance attributes.

        Args:
            properties (dict): a python dictionary of model parameters.
        """
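        # Set up the distributed (NCCL) backend so each GPU rank can hold its tensor-parallel shard.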
        tensor_parallel = properties[constants.TENSOR_PARALLEL_DEGREE]
        deepspeed.init_distributed(constants.NCCL)

        model_location = properties[constants.MODEL_DIR]
        if constants.MODEL_ID in properties:
            model_location = properties[constants.MODEL_ID]
        logging.info(f"Loading model from disk at '{model_location}'.")

        curr_pid = os.getpid()
        logging.info(f"Tensor_parallel={tensor_parallel}::curr_pid={curr_pid}::")

        tokenizer = AutoTokenizer.from_pretrained(model_location)
        config = AutoConfig.from_pretrained(model_location)

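        # Build the model skeleton on the meta device so no real weights are allocated here;
        # DeepSpeed materializes the sharded checkpoint weights during init_inference below.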
        with deepspeed.OnDevice(dtype=torch.float16, device=constants.META):
            model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
        model = model.eval()
        torch.cuda.empty_cache()

        repo_root = model_location

        checkpoints_json = os.path.join(repo_root, constants.DS_INFERENCE_CONFIG_FILE)
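        # Wrap the model with DeepSpeed inference: shard it across `tensor_parallel` GPUs, inject
        # optimized kernels, and load the checkpoint shards described in the DS inference config file.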
        model = deepspeed.init_inference(
            model,
            tensor_parallel={
                constants.TP_SIZE: tensor_parallel,
                constants.ENABLED: True,
                constants.MPU: None,
                constants.TP_GROUP: None,
            },
            base_dir=repo_root,
            dtype=TORCH_DTYPE_FROM_STR_MAPPING[properties.get("dtype")],
            max_tokens=constants.MAX_TOKEN_DS_INIT,
            checkpoint=checkpoints_json,
            replace_with_kernel_inject=True,
        )
        torch.cuda.empty_cache()
        deepspeed.runtime.utils.see_memory_usage(constants.POST_DS_INFERENCE_INIT, force=True)
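        # init_inference returns a DeepSpeed inference engine; unwrap the underlying HF model for generation.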
        model = model.module
        self.model = model
        self.tokenizer = tokenizer

    def inference(self, inputs) -> Union[List[Dict[str, List]], List[List[Dict[str, List]]]]:
        """Conduct inference based on inputs.

        Args:
            inputs (djl_python.inputs.Input): input containing payload and content type.
        Returns:
            results (Union[List[Dict[str, List]], List[List[Dict[str, List]]]]): if a single input string is
            provided, the output is a list of dictionaries whose length equals the number of return sequences;
            if more than one input string is provided (i.e., batch inference), the output is a list of lists,
            where each inner list holds one or more such dictionaries.
        """
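        # Parse the request payload into input text and generation kwargs; surface failures as an error Output.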
        try:
            input_data, model_kwargs = process_input(inputs=inputs, text_input_for_bloom_model=True)
        except Exception as e:
            logging.exception(f"Failed to do inference: {e}; {traceback.format_exc()}")
            results = Output().error(str(e))
            return results

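        # Tokenize the batch with padding and move the resulting tensors onto this rank's GPU before generating.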
        input_tokens = self.tokenizer.batch_encode_plus(
            input_data, return_tensors=constants.RETURN_TENSOR_TYPE, padding=True
        )
        for t in input_tokens:
            if torch.is_tensor(input_tokens[t]):
                input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
        outputs = self.model.generate(**input_tokens, **model_kwargs)
        outputs: List[str] = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

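        # Group the flat list of decoded sequences into one sub-list per input prompt
        # (num_return_sequences entries each); a single decoded sequence is returned as one dictionary in a list.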
        if len(outputs) > 1:
            results: List[List[Dict[str, List]]] = []
            num_return_seq = model_kwargs.get(constants.NUM_RETURN_SEQUENCES, 1)
            for i in range(0, len(outputs), num_return_seq):
                res_tmp: List[dict] = []
                for j in range(i, i + num_return_seq):
                    res_tmp.append({constants.GENERATED_TEXT: outputs[j]})
                results.append(res_tmp)
        else:
            results = [{constants.GENERATED_TEXT: outputs}]
        return results


_service = Service()


def handle(inputs: Input) -> Optional[Output]:
    """Handler method for the Bloom 176B model.

    Args:
        inputs (djl_python.inputs.Input): input containing payload and content type.
    Returns:
        outputs (djl_python.outputs.Output): model prediction output.
    """
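    # An empty request is the model server's initialization call: load the model and return no inference output.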
    if inputs.is_empty():
        _service.load_model(inputs.get_properties())
        return None

    results = _service.inference(inputs)
    return Output().add(results)