from typing import Dict, List, Any

import torch
import transformers
from transformers import AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList

# NOTE(review): the tokenizer id below is an empty string, which cannot resolve
# to a model repo; presumably it must be filled in with the same repo/path that
# EndpointHandler loads the model from — confirm before deploying.
tokenizer = AutoTokenizer.from_pretrained(
    "",
    trust_remote_code=True
)
# Reuse EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
# Left padding only matters for padded batches; the handler below sends one
# unpadded prompt per request.
tokenizer.padding_side = 'left'

# Token ids that end generation early (see StopOnTokens).
stop_token_ids = tokenizer.convert_tokens_to_ids(["<|endoftext|>"])


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the newest token is one of ``stop_token_ids``.

    Only the first sequence of the batch is inspected; this assumes batch size 1,
    which is what EndpointHandler uses (one prompt per request).
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


class EndpointHandler():
    """Custom inference handler: loads a causal LM from ``path`` and generates text."""

    # Generation settings a request may override, either at the top level of the
    # payload or inside ``data["parameters"]`` (top-level values take precedence).
    _OVERRIDABLE_GENERATE_KWARGS = (
        "max_new_tokens",
        "temperature",
        "top_p",
        "top_k",
        "do_sample",
        "repetition_penalty",
    )

    def __init__(self, path=""):
        """Load config and model from ``path`` and move the model to GPU if available.

        Args:
            path: local directory or hub id the model and config are loaded from.
        """
        self.torch_dtype = torch.bfloat16  # alternative: torch.float32
        self.tokenizer = tokenizer
        self.config = transformers.AutoConfig.from_pretrained(
            path,
            trust_remote_code=True
        )
        # Optional, model-specific config tweaks (left disabled):
        #   self.config.attn_config['attn_impl'] = 'triton'
        #   self.config.update({"max_seq_len": 4096})
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            path,
            config=self.config,
            torch_dtype=self.torch_dtype,
            trust_remote_code=True
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval()
        self.model.to(device=device, dtype=self.torch_dtype)

        # Default generation settings. Treated as read-only by __call__, which
        # copies them per request so overrides never persist between requests.
        self.generate_kwargs = {
            'max_new_tokens': 512,
            # Sampling with a near-zero temperature is effectively greedy decoding.
            'temperature': 0.0001,
            'top_p': 1.0,
            'top_k': 0,  # 0 disables top-k filtering
            'use_cache': True,
            'do_sample': True,
            'eos_token_id': self.tokenizer.eos_token_id,
            'pad_token_id': self.tokenizer.pad_token_id,
            "repetition_penalty": 1.1
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for ``data["inputs"]``.

        Args:
            data: request payload. ``data["inputs"]`` is the prompt; if it is absent
                the whole payload is passed to the tokenizer (original behavior).
                Optional overrides for max_new_tokens, temperature, top_p, top_k,
                do_sample and repetition_penalty may appear at the top level or in
                ``data["parameters"]``; top-level values win.

        Returns:
            ``[{"gen_text": <completion>, "input_text": <prompt>}]``, where the
            completion contains only newly generated tokens (prompt excluded).
        """
        # Per-request copy: overrides must not leak into later requests, and
        # concurrent requests must not race on shared handler state.
        generate_kwargs = dict(self.generate_kwargs)
        params = data.get("parameters")
        for source in (params if isinstance(params, dict) else {}, data):
            for key in self._OVERRIDABLE_GENERATE_KWARGS:
                if key in source:
                    generate_kwargs[key] = source[key]
        generate_kwargs['stopping_criteria'] = StoppingCriteriaList([StopOnTokens()])

        # Read (don't pop) the prompt so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        encoded = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        input_ids = encoded["input_ids"]

        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids=input_ids,
                # Pass the mask explicitly: pad_token may equal eos_token, in which
                # case generate() cannot reliably infer it on its own.
                attention_mask=encoded.get("attention_mask"),
                **generate_kwargs,
            )

        # Keep only the tokens generated after the prompt.
        new_tokens = output_ids[0, input_ids.shape[1]:]
        output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return [{"gen_text": output_text, "input_text": inputs}]