# redPajama-3b-zAgile-base / bloom_176b.py
import gc
import logging
import os
import traceback
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
import deepspeed
import torch
from constants import constants
from djl_python.inputs import Input
from djl_python.outputs import Output
from hp_validation import _update_num_beams
from hp_validation import _validate_payload
from input import process_input
from transformers import AutoConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import PreTrainedModel
from transformers import PreTrainedTokenizer
TORCH_DTYPE_FROM_STR_MAPPING = {constants.INT8: torch.int8, constants.FP16: torch.float16}


class Service(object):
    """Define Service class for text generation.

    This class aims to have customized model loading and inference methods.
    """

    def __init__(self):
        """Initialize model and tokenizer."""
        self.model: PreTrainedModel = None
        self.tokenizer: PreTrainedTokenizer = None
    def load_model(self, properties: dict) -> None:
        """Load the Bloom or BloomZ 176B model and tokenizer from disk into memory as instance attributes.

        Args:
            properties (dict): a python dictionary of model parameters.
        """
        tensor_parallel = properties[constants.TENSOR_PARALLEL_DEGREE]
        deepspeed.init_distributed(constants.NCCL)
        model_location = properties[constants.MODEL_DIR]
        if constants.MODEL_ID in properties:
            model_location = properties[constants.MODEL_ID]
        logging.info(f"Loading model from disk at '{model_location}'.")
        curr_pid = os.getpid()
        logging.info(f"Tensor_parallel={tensor_parallel}::curr_pid={curr_pid}::")
        tokenizer = AutoTokenizer.from_pretrained(model_location)
        config = AutoConfig.from_pretrained(model_location)

        # Construct the model with fake meta tensors; they are replaced with real weights
        # during the DeepSpeed-Inference checkpoint load below.
        with deepspeed.OnDevice(dtype=torch.float16, device=constants.META):
            model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
        model = model.eval()
        torch.cuda.empty_cache()

        # DeepSpeed-Inference loading
        repo_root = model_location
        # Tensor-parallel pre-sharded repos come with their own checkpoints config file.
        checkpoints_json = os.path.join(repo_root, constants.DS_INFERENCE_CONFIG_FILE)
        model = deepspeed.init_inference(
            model,
            tensor_parallel={
                constants.TP_SIZE: tensor_parallel,
                constants.ENABLED: True,
                constants.MPU: None,
                constants.TP_GROUP: None,
            },
            base_dir=repo_root,
            dtype=TORCH_DTYPE_FROM_STR_MAPPING[properties.get("dtype")],
            max_tokens=constants.MAX_TOKEN_DS_INIT,
            checkpoint=checkpoints_json,
            replace_with_kernel_inject=True,
        )
        torch.cuda.empty_cache()
        deepspeed.runtime.utils.see_memory_usage(constants.POST_DS_INFERENCE_INIT, force=True)
        model = model.module
        self.model = model
        self.tokenizer = tokenizer
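
    # A minimal sketch of the kind of `properties` dict load_model expects. The actual key
    # names and accepted values are defined by `constants` and the DJL serving configuration
    # (neither shown in this file), so the literal strings below are illustrative assumptions.
    #
    #     example_properties = {
    #         "tensor_parallel_degree": 8,      # constants.TENSOR_PARALLEL_DEGREE
    #         "model_dir": "/opt/ml/model",     # constants.MODEL_DIR
    #         "model_id": "bigscience/bloom",   # optional constants.MODEL_ID override
    #         "dtype": "fp16",                  # key of TORCH_DTYPE_FROM_STR_MAPPING
    #     }
    #     service = Service()
    #     service.load_model(example_properties)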
    def inference(self, inputs) -> Union[List[Dict[str, List]], List[List[Dict[str, List]]]]:
        """Conduct inference based on inputs.

        Args:
            inputs (djl_python.inputs.Input): input containing payload and content type.

        Returns:
            results (Union[List[Dict[str, List]], List[List[Dict[str, List]]]]): if a single input
                string is provided, the output is a list of dictionaries whose length equals the
                number of return sequences; if more than one input string is provided (i.e., batch
                inference), the output is a list of lists, where each inner list contains one or
                more dictionaries.
        """
        try:
            input_data, model_kwargs = process_input(inputs=inputs, text_input_for_bloom_model=True)
        except Exception as e:
            logging.exception(f"Failed to do inference: {e}; {traceback.format_exc()}")
            results = Output().error(str(e))
            return results

        input_tokens = self.tokenizer.batch_encode_plus(
            input_data, return_tensors=constants.RETURN_TENSOR_TYPE, padding=True
        )
        # Move all input tensors to the current GPU before generation.
        for t in input_tokens:
            if torch.is_tensor(input_tokens[t]):
                input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
        outputs = self.model.generate(**input_tokens, **model_kwargs)
        outputs: List[str] = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        if len(outputs) > 1:
            # Batch inference: group the decoded sequences per input prompt.
            results: List[List[Dict[str, List]]] = []
            num_return_seq = model_kwargs.get(constants.NUM_RETURN_SEQUENCES, 1)
            for i in range(0, len(outputs), num_return_seq):
                res_tmp: List[dict] = []
                for j in range(i, i + num_return_seq):
                    res_tmp.append({constants.GENERATED_TEXT: outputs[j]})
                results.append(res_tmp)
        else:
            results = [{constants.GENERATED_TEXT: outputs}]
        return results
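
# Illustrative shapes of what Service.inference returns, assuming (not shown in this
# file) that constants.GENERATED_TEXT == "generated_text":
#
#     one prompt, num_return_sequences=1:
#         [{"generated_text": ["<prompt completion>"]}]
#     two prompts, num_return_sequences=2 (batch path):
#         [[{"generated_text": "..."}, {"generated_text": "..."}],
#          [{"generated_text": "..."}, {"generated_text": "..."}]]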


_service = Service()


def handle(inputs: Input) -> Optional[Output]:
    """Define handler method for the Bloom 176B model.

    Args:
        inputs (djl_python.inputs.Input): input containing payload and content type.

    Returns:
        outputs (djl_python.outputs.Output): model prediction output.
    """
    if inputs.is_empty():
        # The model server makes an empty call to warm up the model on startup.
        _service.load_model(inputs.get_properties())
        return None
    results = _service.inference(inputs)
    return Output().add(results)
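

# A minimal sketch of how the DJL model server is expected to drive this handler. This is
# not part of the original file, and the payload keys are assumptions (the real schema is
# defined by process_input in input.py):
#
#   1. On startup the server calls handle() with an empty Input; its properties
#      (tensor parallel degree, model location, dtype, ...) are passed to load_model(),
#      which initializes DeepSpeed-Inference across the GPUs and returns None.
#   2. Each later request carries a JSON payload along the lines of
#          {"text_inputs": ["a sample prompt"], "num_return_sequences": 2}
#      which process_input converts into (input_data, model_kwargs) before generation,
#      and the decoded generations are returned via Output().add(results).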