import logging
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

from constants import constants
from djl_python.inputs import Input
from hp_validation import _update_num_beams
from hp_validation import _validate_payload
from transformers import set_seed


def format_input_for_task(input_values: Union[str, List[str]]) -> List[str]:
    """Format input string into a list for text generation task.

    Args:
        input_values: either a text string or list of text string.

    Returns:
        input_values (list): list of text string. Any non-list value is wrapped in a
            single-element list; a list is returned unchanged.
    """
    if not isinstance(input_values, list):
        input_values = [input_values]
    return input_values


def process_input(
    inputs: Input, text_input_for_bloom_model: bool = True
) -> Tuple[Union[List, str], Dict]:
    """Process input based on content type.

    Parse the input based on Content-Type of application/json or application/x-text.

    Args:
        inputs (djl_python.inputs.Input): input containing payload and content type.
        text_input_for_bloom_model: whether the text string needs special handling for
            the bloom model. When True, an application/x-text payload is wrapped in a
            list; when False, it is returned as a plain string.

    Returns:
        A tuple ``(input_data, model_kwargs)``:
            input_data: for a JSON dict payload, the ``constants.TEXT_INPUTS`` value
                formatted as a list; for a non-dict JSON payload, the parsed JSON value
                unchanged; for text payloads, a list of one string or a plain string
                depending on ``text_input_for_bloom_model``.
            model_kwargs (dict): remaining JSON dict entries (after removing the text
                inputs and seed); empty for non-dict or text payloads.

    Raises:
        ValueError: if the Content-Type is neither application/json nor
            application/x-text.
        Exception: any error raised while parsing/validating the payload is logged and
            re-raised unchanged.
    """
    content_type = inputs.get_property("Content-Type")
    model_kwargs = {}
    if content_type == constants.APPLICATION_JSON:
        try:
            json_input = inputs.get_as_json()
            _validate_payload(json_input)
            json_input = _update_num_beams(json_input)
            if isinstance(json_input, dict):
                # Seed handling must happen only for dict payloads: on a str or list,
                # `SEED in json_input` is a substring/element test and the subsequent
                # `json_input[SEED]` lookup would raise TypeError.
                if constants.SEED in json_input:
                    # Remove the seed so it is not forwarded as a generation kwarg.
                    set_seed(json_input.pop(constants.SEED))
                input_data = format_input_for_task(json_input.pop(constants.TEXT_INPUTS))
                model_kwargs = json_input
            else:
                input_data = json_input
        except Exception:
            logging.exception(
                f"Failed to parse input payload. For content_type={constants.APPLICATION_JSON}, input "
                f"payload must be a json encoded dictionary with keys {constants.ALL_PARAM_NAMES}."
            )
            raise
    elif content_type == constants.APPLICATION_X_TEXT:
        try:
            if text_input_for_bloom_model:
                input_data = format_input_for_task(inputs.get_as_string())
            else:
                input_data = inputs.get_as_string()
        except Exception:
            logging.exception(
                f"Failed to parse input payload. For content_type={constants.APPLICATION_X_TEXT}, input "
                f"payload must be a string encoded in utf-8 format."
            )
            raise
    else:
        raise ValueError('{{"error": "unsupported content type {}"}}'.format(content_type or "unknown"))
    return input_data, model_kwargs