import os
from pathlib import Path
from typing import Any
from typing import Dict
import torch
from constants import constants
from djl_python.inputs import Input
from djl_python.outputs import Output
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.dtypes import get_torch_dtype_from_str
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.handle import create_handle
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.inference.textgeneration import format_djl_output
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.inference.textgeneration import generate_text
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.inference.textgeneration import (
model_output_to_batch_output,
)
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.inference.textgeneration import process_input
from sagemaker_jumpstart_huggingface_script_utilities.payload.stopping_criteria import (
add_stopping_criteria_to_model_kwargs,
)
from sagemaker_jumpstart_script_utilities.subprocess import run_with_error_handling
from transformers import AutoConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

FLASH_ATTENTION_WHEEL_FILENAME = "flash_attn-1.0.5-cp39-cp39-linux_x86_64.whl"
MAXIMUM_INPUT_SEQUENCE_LENGTH = "MAXIMUM_INPUT_SEQUENCE_LENGTH"
DTYPE = "dtype"
MAX_SEQ_LEN = "max_seq_len"
ATTN_IMPL = "attn_impl"
TRITON = "triton"
LIB = "lib"
EXTRA_DEPENDENCIES = "extra_dependencies"
EXTRA_DEPENDENCIES_PATH = Path(__file__).parent / LIB / EXTRA_DEPENDENCIES

# As of this script's creation, the maintainers of the flash-attn package do not distribute a bdist. To avoid an
# approximately 10-minute source build, a pre-built wheel is hosted as an extra dependency.
run_with_error_handling(["pip", "install", str(EXTRA_DEPENDENCIES_PATH / FLASH_ATTENTION_WHEEL_FILENAME)])


class MptPythonServiceTextGeneration:
    """A service object for the MPT model family using the DJL Python engine.

    This set of MPT models is not compatible with the T4 GPUs on the G4 instance family because some of the Triton
    kernels do not support GPU architectures older than Ampere
    ([reference](https://github.com/microsoft/DeepSpeed-MII/issues/170#issuecomment-1526277566)).

    This set of MPT models uses the [FlashAttention](https://arxiv.org/pdf/2205.14135.pdf) mechanism, whose
    implementation requires the Python environment to include the following libraries:

    - [flash-attn](https://github.com/HazyResearch/flash-attention)
    - [triton](https://github.com/openai/triton)
    - [einops](https://github.com/arogozhnikov/einops)
    """

    def __init__(self) -> None:
        """Set the initialization flag to False; the model is initialized when `initialize` is invoked."""
        self.initialized = False

    def initialize(self, properties: Dict[str, Any]) -> None:
        """Initialize the MPT model and tokenizer.

        This model contains custom Python scripts that are not yet included in the transformers library, so the
        non-standard `trust_remote_code=True` flag must be set when loading the config and the model. The referenced
        remote code exists within the model artifact, which has been manually inspected for malicious code, downloaded
        with commit revision hash checking, and prepackaged within the script tarball. The "remote" code is therefore
        hosted on S3 via the SageMaker JumpStart deployment pipeline, and this script does not make requests to the
        public internet. Because of these steps, and because SageMaker JumpStart deploys models with network isolation
        enabled by default, these scripts are deemed safe to load with the `trust_remote_code=True` flag.

        Additionally, the accelerate library does not currently support the MPT model type, so the model is moved to a
        single GPU explicitly rather than dispatched with a `device_map`.
        """
model_dir = properties[constants.MODEL_DIR]
# View docstring for details on this `trust_remote_code=True` flag.
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
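        # Use the Triton FlashAttention kernel implementation (see the class docstring for the required libraries).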
config.attn_config[ATTN_IMPL] = TRITON
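        # Optionally override the model's maximum input sequence length via an environment variable.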
maximum_sequence_length = os.environ.get(MAXIMUM_INPUT_SEQUENCE_LENGTH)
if maximum_sequence_length is not None:
config.update({MAX_SEQ_LEN: int(maximum_sequence_length)})
torch_dtype = get_torch_dtype_from_str(properties.get(DTYPE))
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, config=config, torch_dtype=torch_dtype, trust_remote_code=True
)
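        # accelerate/device_map is not supported for MPT (see the docstring), so place the model on GPU 0 explicitly.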
self.model.to(device="cuda:0")
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
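        # The tokenizer does not define a pad token by default, so reuse the EOS token for padding.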
self.tokenizer.pad_token = self.tokenizer.eos_token
self.initialized = True

    @format_djl_output
    def inference(self, inputs: Input) -> Output:
        """Perform inference for the text generation task."""
input_data, model_kwargs = process_input(inputs, input_data_as_list=False)
model_kwargs = add_stopping_criteria_to_model_kwargs(model_kwargs, self.tokenizer)
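        # Run generation and convert the raw model output into the DJL batch output format.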
model_output = generate_text(self.model, self.tokenizer, input_data, model_kwargs)
return model_output_to_batch_output(model_output, model_kwargs)
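

# Create the module-level service instance and the DJL serving entry point.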
_service = MptPythonServiceTextGeneration()
handle = create_handle(_service)
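

# Illustrative request body for the deployed endpoint (a sketch only; the exact parameter names accepted, e.g.
# "text_inputs" and "max_new_tokens", are determined by the sagemaker_jumpstart_huggingface_script_utilities
# text-generation helpers rather than hard-coded here):
# {"text_inputs": "Write a short poem about the ocean.", "max_new_tokens": 64, "temperature": 0.7}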