import logging
import traceback
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

import torch
from djl_python import Input
from djl_python import Output
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import PreTrainedModel
from transformers import PreTrainedTokenizer
from transformers import PreTrainedTokenizerBase
from transformers import StoppingCriteria
from transformers import StoppingCriteriaList

from constants import constants
from inference_helper import inference_helper_model_tokenizer
from input import process_input


class StopWordsCriteria(StoppingCriteria):
    """A text generation stopping criterion that fires when the output sequence contains any of the specified stop words."""

    def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_words: List[str]) -> None:
        """Initialize stopping criteria."""
        self._tokenizer = tokenizer
        self._stop_words = stop_words
        self._partial_result = ""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Any) -> bool:
        """Return True if any stop word is present in the generated sequence."""
        # Decode only the most recently generated token and accumulate the
        # decoded text, so stop words that span multiple tokens are detected.
        text = self._tokenizer.decode(input_ids[0, -1])
        self._partial_result += text
        for stop_word in self._stop_words:
            if stop_word in self._partial_result:
                return True
        return False


def create_stopping_criteria_list(
    tokenizer: PreTrainedTokenizerBase, stop_words: List[str]
) -> StoppingCriteriaList:
    """Create a StoppingCriteriaList to be used as generation payload input."""
    stop_criteria = StopWordsCriteria(tokenizer, stop_words)
    return StoppingCriteriaList([stop_criteria])
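
# Usage sketch (illustrative only; the model name below is a placeholder, not
# part of this service): the list returned by `create_stopping_criteria_list`
# plugs directly into the `stopping_criteria` kwarg of `generate`, halting
# decoding once any stop word appears in the decoded output.
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     encoded = tokenizer("The sky is", return_tensors="pt")
#     generated = model.generate(
#         **encoded,
#         max_new_tokens=50,
#         stopping_criteria=create_stopping_criteria_list(tokenizer, ["."]),
#     )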
class Service:
    """Define Service class for text generation.

    This class aims to have customized model loading and inference methods.
    """

    def __init__(self) -> None:
        """Initialize model and tokenizer."""
        self.model: Optional[PreTrainedModel] = None
        self.tokenizer: Optional[PreTrainedTokenizer] = None

    def load_model(self, properties: dict) -> None:
        """Load model and tokenizer from disk to memory as instance attributes.

        Args:
            properties (dict): a python dictionary of model parameters.
        """
        model_location = properties["model_dir"]
        logging.info(f"Loading model from {model_location}")
        tokenizer = AutoTokenizer.from_pretrained(model_location)
        # Causal LM tokenizers often lack a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_location, torch_dtype=torch.float16, device_map="auto"
        )
        # Inference only: freeze parameters and switch to evaluation mode.
        model.requires_grad_(False)
        model.eval()
        self.model = model
        self.tokenizer = tokenizer

    def modify_model_kwargs(self, model_kwargs: Dict[str, Any]) -> None:
        """Provide script-based modification of model kwargs.

        By default, this method injects stopping criteria into text generation.
        Since payload parameters are serialized when invoking the endpoint, this
        method initializes any non-serializable kwargs for text generation.
        """
        if constants.STOPPING_CRITERIA in model_kwargs:
            # Replace the serialized list of stop words with a
            # StoppingCriteriaList object that `generate` can consume.
            stopping_criteria = model_kwargs.pop(constants.STOPPING_CRITERIA)
            stopping_criteria_list = create_stopping_criteria_list(self.tokenizer, stopping_criteria)
            model_kwargs[constants.STOPPING_CRITERIA] = stopping_criteria_list

    def inference(self, inputs: Input) -> Union[List[Dict[str, List]], List[List[Dict[str, List]]], Output]:
        """Conduct inference based on inputs.

        Args:
            inputs (djl_python.inputs.Input): input containing payload and content type.

        Returns:
            results (Union[List[Dict[str, List]], List[List[Dict[str, List]]], Output]): if the
                input list has length one, the output is a list of dictionaries, where the number
                of dictionaries corresponds to the number of return sequences; if the input list
                has length greater than one (i.e., batch inference), the output is a list of lists,
                where each inner list contains one or more dictionaries. On failure, an error
                Output is returned instead.
        """
        try:
            input_data, model_kwargs = process_input(inputs=inputs, text_input_for_bloom_model=True)
            self.modify_model_kwargs(model_kwargs)
            content_type = inputs.get_property("Content-Type")
            return inference_helper_model_tokenizer(
                input_data, self.model, self.tokenizer, content_type, model_kwargs
            )
        except Exception as e:
            logging.exception(f"Failed to do inference: {e}; {traceback.format_exc()}")
            return Output().error(str(e))


_service = Service()


def handle(inputs: Input) -> Optional[Output]:
    """Define handler method for model.

    Args:
        inputs (djl_python.inputs.Input): input containing payload and content type.

    Returns:
        outputs (djl_python.inputs.Output): model prediction output.
    """
    if inputs.is_empty():
        # Model server makes an empty call to warm up the model on startup.
        _service.load_model(inputs.get_properties())
        return None
    results = _service.inference(inputs)
    if isinstance(results, Output):
        # Inference already produced an error Output; return it as-is rather
        # than wrapping an Output inside another Output.
        return results
    return Output().add(results)
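
# Example request payload (illustrative assumption: the exact input key is
# determined by `process_input`, and the value of `constants.STOPPING_CRITERIA`
# is assumed here to be "stopping_criteria"). The stop-word list arrives as a
# plain JSON array and is swapped for a StoppingCriteriaList by
# `modify_model_kwargs` before generation:
#
#     {
#         "text_inputs": "Once upon a time",
#         "max_new_tokens": 64,
#         "stopping_criteria": ["\n\n"]
#     }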