import gc
import logging
import os
import traceback
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

import deepspeed
import torch
from constants import constants
from djl_python.inputs import Input
from djl_python.outputs import Output
from hp_validation import _update_num_beams
from hp_validation import _validate_payload
from input import process_input
from transformers import AutoConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import PreTrainedModel
from transformers import PreTrainedTokenizer

TORCH_DTYPE_FROM_STR_MAPPING = {constants.INT8: torch.int8, constants.FP16: torch.float16}


class Service(object):
    """Service class for text generation.

    Provides customized model loading and inference methods.
    """

    def __init__(self):
        """Initialize model and tokenizer attributes."""
        self.model: Optional[PreTrainedModel] = None
        self.tokenizer: Optional[PreTrainedTokenizer] = None

    def load_model(self, properties: dict) -> None:
        """Load the Bloom or BloomZ 176B model and tokenizer from disk into memory as instance attributes.

        Args:
            properties (dict): a python dictionary of model parameters.
        """
        tensor_parallel = properties[constants.TENSOR_PARALLEL_DEGREE]
        deepspeed.init_distributed(constants.NCCL)
        model_location = properties[constants.MODEL_DIR]
        if constants.MODEL_ID in properties:
            model_location = properties[constants.MODEL_ID]
        logging.info(f"Loading model from disk at '{model_location}'.")
        curr_pid = os.getpid()
        logging.info(f"Tensor_parallel={tensor_parallel}::curr_pid={curr_pid}::")

        tokenizer = AutoTokenizer.from_pretrained(model_location)
        config = AutoConfig.from_pretrained(model_location)

        # Construct the model with meta tensors; the weights are materialized later during the
        # DeepSpeed-Inference checkpoint load.
        with deepspeed.OnDevice(dtype=torch.float16, device=constants.META):
            model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
        model = model.eval()
        torch.cuda.empty_cache()

        # DeepSpeed-Inference loading: tensor-parallel pre-sharded repos ship their own
        # checkpoints config file.
        repo_root = model_location
        checkpoints_json = os.path.join(repo_root, constants.DS_INFERENCE_CONFIG_FILE)

        model = deepspeed.init_inference(
            model,
            tensor_parallel={
                constants.TP_SIZE: tensor_parallel,
                constants.ENABLED: True,
                constants.MPU: None,
                constants.TP_GROUP: None,
            },
            base_dir=repo_root,
            dtype=TORCH_DTYPE_FROM_STR_MAPPING[properties.get("dtype")],
            max_tokens=constants.MAX_TOKEN_DS_INIT,
            checkpoint=checkpoints_json,
            replace_with_kernel_inject=True,
        )
        torch.cuda.empty_cache()
        deepspeed.runtime.utils.see_memory_usage(constants.POST_DS_INFERENCE_INIT, force=True)
        model = model.module

        self.model = model
        self.tokenizer = tokenizer

    def inference(self, inputs: Input) -> Union[List[Dict[str, List]], List[List[Dict[str, List]]]]:
        """Conduct inference based on inputs.

        Args:
            inputs (djl_python.inputs.Input): input containing payload and content type.

        Returns:
            results (Union[List[Dict[str, List]], List[List[Dict[str, List]]]]): if the input is a
                single string, the output is a list of dictionaries whose length corresponds to the
                number of return sequences; if the input contains multiple strings (i.e., batch
                inference), the output is a list of lists, where each inner list contains one or
                more such dictionaries.
        """
        try:
            input_data, model_kwargs = process_input(inputs=inputs, text_input_for_bloom_model=True)
        except Exception as e:
            logging.exception(f"Failed to do inference: {e}; {traceback.format_exc()}")
            return Output().error(str(e))

        input_tokens = self.tokenizer.batch_encode_plus(
            input_data, return_tensors=constants.RETURN_TENSOR_TYPE, padding=True
        )
        # Move every tensor in the encoded batch to the current CUDA device.
        for t in input_tokens:
            if torch.is_tensor(input_tokens[t]):
                input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())

        outputs = self.model.generate(**input_tokens, **model_kwargs)
        outputs: List[str] = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        if len(outputs) > 1:
            # Batch inference: group the decoded sequences so that each prompt receives its own
            # list of `num_return_sequences` generations.
            results: List[List[Dict[str, List]]] = []
            num_return_seq = model_kwargs.get(constants.NUM_RETURN_SEQUENCES, 1)
            for i in range(0, len(outputs), num_return_seq):
                res_tmp: List[dict] = []
                for j in range(i, i + num_return_seq):
                    res_tmp.append({constants.GENERATED_TEXT: outputs[j]})
                results.append(res_tmp)
        else:
            results = [{constants.GENERATED_TEXT: outputs}]
        return results


_service = Service()


def handle(inputs: Input) -> Optional[Output]:
    """Define the handler method for the Bloom 176B model.

    Args:
        inputs (djl_python.inputs.Input): input containing payload and content type.

    Returns:
        outputs (djl_python.outputs.Output): model prediction output.
    """
    if inputs.is_empty():
        # The model server makes an empty call to warm up the model on startup.
        _service.load_model(inputs.get_properties())
        return None
    results = _service.inference(inputs)
    return Output().add(results)