import os
from pathlib import Path
from typing import Any
from typing import Dict

import torch
from constants import constants
from djl_python.inputs import Input
from djl_python.outputs import Output
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.dtypes import get_torch_dtype_from_str
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.handle import create_handle
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.inference.textgeneration import format_djl_output
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.inference.textgeneration import generate_text
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.inference.textgeneration import (
    model_output_to_batch_output,
)
from sagemaker_jumpstart_huggingface_script_utilities.djl_python.inference.textgeneration import process_input
from sagemaker_jumpstart_huggingface_script_utilities.payload.stopping_criteria import (
    add_stopping_criteria_to_model_kwargs,
)
from sagemaker_jumpstart_script_utilities.subprocess import run_with_error_handling
from transformers import AutoConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

FLASH_ATTENTION_WHEEL_FILENAME = "flash_attn-1.0.5-cp39-cp39-linux_x86_64.whl"
MAXIMUM_INPUT_SEQUENCE_LENGTH = "MAXIMUM_INPUT_SEQUENCE_LENGTH"
DTYPE = "dtype"
MAX_SEQ_LEN = "max_seq_len"
ATTN_IMPL = "attn_impl"
TRITON = "triton"
LIB = "lib"
EXTRA_DEPENDENCIES = "extra_dependencies"
EXTRA_DEPENDENCIES_PATH = Path(__file__).parent / LIB / EXTRA_DEPENDENCIES

# As of the time of this script's creation, the maintainers of the flash-attn package do not distribute a bdist. To
# avoid an approximately 10 minute source build, a pre-built wheel is hosted as an extra dependency.
run_with_error_handling(["pip", "install", str(EXTRA_DEPENDENCIES_PATH / FLASH_ATTENTION_WHEEL_FILENAME)])


class MptPythonServiceTextGeneration:
    """A service object for the MPT model family using the DJL Python engine.

    This set of MPT models is not compatible with the T4 GPUs on the G4 instance family because some Triton kernels
    are not compatible with architectures older than Ampere
    ([reference](https://github.com/microsoft/DeepSpeed-MII/issues/170#issuecomment-1526277566)).

    This set of MPT models uses the [FlashAttention](https://arxiv.org/pdf/2205.14135.pdf) mechanism with an
    implementation that requires the Python environment to include the following libraries:
    - [flash-attn](https://github.com/HazyResearch/flash-attention)
    - [triton](https://github.com/openai/triton)
    - [einops](https://github.com/arogozhnikov/einops)
    """

    def __init__(self) -> None:
        """Set the initialization flag to False; the model is initialized upon invocation of the initialize method."""
        self.initialized = False

    def initialize(self, properties: Dict[str, Any]) -> None:
        """Initialize the MPT model and tokenizer.

        This model contains custom Python scripts that are not yet included in the transformers library. Therefore,
        the non-standard `trust_remote_code=True` flag must be set when loading the config and the model. The
        referenced remote code exists within the model artifact, which has been manually inspected for malicious
        code, downloaded with commit revision hash checking, and prepackaged within the script tarball. Consequently,
        the "remote" code is hosted on S3 via the SageMaker JumpStart deployment pipeline, and this script will not
        perform requests to the public internet. Due to these steps and the fact that SageMaker JumpStart deploys
        models with network isolation enabled by default, these scripts are deemed safe to enable the
        `trust_remote_code=True` flag.

        Additionally, the accelerate library currently does not support the MPT model type.
        """
        model_dir = properties[constants.MODEL_DIR]

        # View docstring for details on this `trust_remote_code=True` flag.
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        config.attn_config[ATTN_IMPL] = TRITON

        maximum_sequence_length = os.environ.get(MAXIMUM_INPUT_SEQUENCE_LENGTH)
        if maximum_sequence_length is not None:
            config.update({MAX_SEQ_LEN: int(maximum_sequence_length)})

        torch_dtype = get_torch_dtype_from_str(properties.get(DTYPE))
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, config=config, torch_dtype=torch_dtype, trust_remote_code=True
        )
        self.model.to(device="cuda:0")

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.initialized = True

    @format_djl_output
    def inference(self, inputs: Input) -> Output:
        """Perform inference for the text generation task."""
        input_data, model_kwargs = process_input(inputs, input_data_as_list=False)
        model_kwargs = add_stopping_criteria_to_model_kwargs(model_kwargs, self.tokenizer)
        model_output = generate_text(self.model, self.tokenizer, input_data, model_kwargs)
        return model_output_to_batch_output(model_output, model_kwargs)


_service = MptPythonServiceTextGeneration()

handle = create_handle(_service)
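

# -----------------------------------------------------------------------------------------------------------------
# Illustrative client-side usage sketch (not part of the DJL serving contract above): once this script is deployed
# behind a SageMaker endpoint, a JSON payload roughly like the one below can be sent via the SageMaker runtime API.
# The endpoint name and the exact generation parameters accepted by `process_input` are assumptions for
# illustration only; consult the deployed model's documentation for the supported payload schema.
if __name__ == "__main__":
    import json

    import boto3

    sagemaker_runtime = boto3.client("sagemaker-runtime")
    example_payload = {
        "text_inputs": "Write a haiku about large language models.",  # assumed prompt key
        "max_new_tokens": 64,  # assumed to be forwarded to model.generate via model_kwargs
    }
    response = sagemaker_runtime.invoke_endpoint(
        EndpointName="jumpstart-mpt-7b-endpoint",  # hypothetical endpoint name
        ContentType="application/json",
        Body=json.dumps(example_payload),
    )
    print(response["Body"].read().decode("utf-8"))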