from typing import Dict
from typing import List
from typing import Union

import torch

from constants import constants
from djl_python import Input
from djl_python import Output
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer


def inference_helper_model_tokenizer(
    input_data: List,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    content_type: str,
    model_kwargs: Dict,
) -> Union[List[Dict[str, List]], List[List[Dict[str, List]]]]:
    """Generate text for a batch of prompts and package the decoded sequences.

    The prompts are tokenized with padding, every tensor in the encoding is
    moved to the current CUDA device, ``model.generate`` is invoked with
    ``model_kwargs``, and the generated ids are decoded with special tokens
    skipped.

    Args:
        input_data (list): a python list of prompt strings.
        model (AutoModelForCausalLM): model used for generation.
        tokenizer (AutoTokenizer): tokenizer used to encode the prompts and
            decode the generated ids.
        content_type (str): request content type; it is only consulted when
            at most one sequence was decoded.
        model_kwargs (dict): keyword arguments forwarded to
            ``model.generate``; the entry keyed by
            ``constants.NUM_RETURN_SEQUENCES`` (default 1) is also used as
            the group size when more than one sequence is decoded.

    Returns:
        results (Union[List[Dict[str, List]], List[List[Dict[str, List]]]]):
            * More than one decoded sequence: a list of lists. Consecutive
              decoded sequences are grouped ``num_return_sequences`` at a
              time, and each inner list holds one
              ``{constants.GENERATED_TEXT: text}`` dictionary per sequence.
            * Otherwise: a single dictionary mapping
              ``constants.GENERATED_TEXT`` to the whole list of decoded
              sequences, wrapped in one list when ``content_type`` equals
              ``constants.APPLICATION_X_TEXT`` and in two nested lists for
              any other content type.
    """
    encoded = tokenizer.batch_encode_plus(
        input_data,
        return_tensors=constants.RETURN_TENSOR_TYPE,
        padding=True,
    )

    # Only tensor entries are relocated; anything else in the encoding is
    # left exactly as the tokenizer produced it.
    for field in encoded:
        value = encoded[field]
        if torch.is_tensor(value):
            encoded[field] = value.to(torch.cuda.current_device())

    generated_ids = model.generate(**encoded, **model_kwargs)
    decoded: List[str] = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Zero or one decoded sequence: the wrapping depth depends on the
    # content type, and the dictionary value is the decoded list itself.
    if len(decoded) <= 1:
        is_plain_text = content_type == constants.APPLICATION_X_TEXT
        payload = {constants.GENERATED_TEXT: decoded}
        return [payload] if is_plain_text else [[payload]]

    # Several decoded sequences: split them into consecutive groups of
    # ``num_return_sequences`` entries each.
    group_size = model_kwargs.get(constants.NUM_RETURN_SEQUENCES, 1)
    return [
        [{constants.GENERATED_TEXT: text} for text in decoded[start:start + group_size]]
        for start in range(0, len(decoded), group_size)
    ]