import os
from pathlib import Path
from typing import Any
from typing import Dict

import torch
from constants import constants
from djl_python.inputs import Input
from djl_python.outputs import Output
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.dtypes import get_torch_dtype_from_str
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.handle import create_handle
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.inference.textgeneration import format_djl_output
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.inference.textgeneration import generate_text
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.inference.textgeneration import (
    model_output_to_batch_output,
)
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.inference.textgeneration import process_input
from sagemaker_jumpstart_huggingface_script_utilities.payload.stopping_criteria import (
    add_stopping_criteria_to_model_kwargs,
)
from sagemaker_jumpstart_script_utilities.subprocess import run_with_error_handling
from transformers import AutoConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer


FLASH_ATTENTION_WHEEL_FILENAME = "flash_attn-1.0.5-cp39-cp39-linux_x86_64.whl"
MAXIMUM_INPUT_SEQUENCE_LENGTH = "MAXIMUM_INPUT_SEQUENCE_LENGTH"
DTYPE = "dtype"
MAX_SEQ_LEN = "max_seq_len"
ATTN_IMPL = "attn_impl"
TRITON = "triton"
LIB = "lib"
EXTRA_DEPENDENCIES = "extra_dependencies"
EXTRA_DEPENDENCIES_PATH = Path(__file__).parent / LIB / EXTRA_DEPENDENCIES


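# Install the flash-attn wheel prepackaged under lib/extra_dependencies; the wheel ships with the script tarball, so
# the install does not require public internet access.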
run_with_error_handling(["pip", "install", str(EXTRA_DEPENDENCIES_PATH / FLASH_ATTENTION_WHEEL_FILENAME)])


class MptPythonServiceTextGeneration:
    """A service object for the MPT model family using the DJL Python engine.

    This set of MPT models is not compatible with the T4 GPUs on the G4 instance family because some Triton kernels
    are incompatible with GPU architectures older than Ampere
    ([reference](https://github.com/microsoft/DeepSpeed-MII/issues/170#issuecomment-1526277566)).

    This set of MPT models uses the [FlashAttention](https://arxiv.org/pdf/2205.14135.pdf) mechanism with an
    implementation that requires the Python environment to include the following libraries:
    - [flash-attn](https://github.com/HazyResearch/flash-attention)
    - [triton](https://github.com/openai/triton)
    - [einops](https://github.com/arogozhnikov/einops)
    """

    def __init__(self) -> None:
        """Set the initialization flag to False; the model is initialized upon invocation of the `initialize` method."""
        self.initialized = False

    def initialize(self, properties: Dict[str, Any]) -> None:
        """Initialize the MPT model and tokenizer.

        This model contains custom Python scripts that are not yet included in the transformers library, so the
        non-standard `trust_remote_code=True` flag must be set when loading the config and the model. The referenced
        remote code exists within the model artifact, which has been manually inspected for malicious code, downloaded
        with commit revision hash checking, and prepackaged within the script tarball. The "remote" code is therefore
        hosted on S3 via the SageMaker JumpStart deployment pipeline, and this script does not make requests to the
        public internet. Because of these steps, and because SageMaker JumpStart deploys models with network isolation
        enabled by default, these scripts are deemed safe to use with the `trust_remote_code=True` flag.

        Additionally, the accelerate library currently does not support the MPT model type.
        """
        model_dir = properties[constants.MODEL_DIR]

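        # Use the Triton FlashAttention implementation and, when the MAXIMUM_INPUT_SEQUENCE_LENGTH environment
        # variable is set, override the model's maximum input sequence length.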
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        config.attn_config[ATTN_IMPL] = TRITON
        maximum_sequence_length = os.environ.get(MAXIMUM_INPUT_SEQUENCE_LENGTH)
        if maximum_sequence_length is not None:
            config.update({MAX_SEQ_LEN: int(maximum_sequence_length)})

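        # Resolve the torch dtype (for example, fp16 or bf16) from the `dtype` entry in the serving properties.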
        torch_dtype = get_torch_dtype_from_str(properties.get(DTYPE))

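        # Load the model with the patched config and place it on the first GPU; trust_remote_code=True is required
        # because the MPT modeling code is bundled with the model artifact (see the docstring above).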
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, config=config, torch_dtype=torch_dtype, trust_remote_code=True
        )
        self.model.to(device="cuda:0")

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
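        # The tokenizer does not ship with a pad token, so reuse the EOS token for padding.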
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.initialized = True

    @format_djl_output
    def inference(self, inputs: Input) -> Output:
        """Perform inference for the text generation task."""
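        # Parse the request payload and generation kwargs, attach stopping criteria, run generation, and format the
        # result as a batch output.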
        input_data, model_kwargs = process_input(inputs, input_data_as_list=False)
        model_kwargs = add_stopping_criteria_to_model_kwargs(model_kwargs, self.tokenizer)
        model_output = generate_text(self.model, self.tokenizer, input_data, model_kwargs)
        return model_output_to_batch_output(model_output, model_kwargs)


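# Module-level service instance; `handle` is the entry point invoked by the DJL Python engine for each request.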
_service = MptPythonServiceTextGeneration()
handle = create_handle(_service)