# blip2-flan-t5-xxl/handler.py
from typing import Any, Dict, List

import torch
import transformers
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList
# Custom stopping criterion: ends generation as soon as the most recently
# generated token matches one of the configured stop token ids.
class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_token_ids: List[int]):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return input_ids[0, -1].item() in self.stop_token_ids
class EndpointHandler():
    def __init__(self, path=""):
        self.torch_dtype = torch.bfloat16
        # self.torch_dtype = torch.float32

        # Load the tokenizer from the endpoint's model path instead of at
        # module import time, so the deployed repository is actually used.
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = 'left'
        self.stop_token_ids = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])

        self.config = transformers.AutoConfig.from_pretrained(
            path,
            trust_remote_code=True
        )
        # self.config.attn_config['attn_impl'] = 'triton'
        # self.config.update({"max_seq_len": 4096})

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            path,
            config=self.config,
            torch_dtype=self.torch_dtype,
            trust_remote_code=True
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval()
        self.model.to(device=device, dtype=self.torch_dtype)
        self.generate_kwargs = {
            'max_new_tokens': 512,
            # A near-zero temperature with do_sample=True keeps the sampling
            # code path while making generation effectively greedy.
            'temperature': 0.0001,
            'top_p': 1.0,
            'top_k': 0,
            'use_cache': True,
            'do_sample': True,
            'eos_token_id': self.tokenizer.eos_token_id,
            'pad_token_id': self.tokenizer.pad_token_id,
            'repetition_penalty': 1.1,
        }
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`): the prompt to complete
            max_new_tokens, temperature, top_p, top_k, do_sample,
            repetition_penalty: optional per-request generation overrides
        Return:
            A :obj:`list` of :obj:`dict` that will be serialized and returned
            (see the __main__ smoke test at the end of this file for an
            example payload)
        """
        # streamer = TextIteratorStreamer(
        #     self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        # )
        stop = StopOnTokens(self.stop_token_ids)

        ## Model parameters: copy the defaults so per-request overrides do not
        ## leak into later requests, then apply any caller-supplied values.
        generate_kwargs = dict(self.generate_kwargs)
        for key in ('max_new_tokens', 'temperature', 'top_p', 'top_k',
                    'do_sample', 'repetition_penalty'):
            if key in data:
                generate_kwargs[key] = data[key]

        ## Add the streamer and stopping criteria
        # generate_kwargs['streamer'] = streamer
        generate_kwargs['stopping_criteria'] = StoppingCriteriaList([stop])
        ## Prepare the inputs
        inputs = data.pop("inputs", data)
        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
        input_ids = input_ids.to(self.model.device)
        # encoded_inp = self.tokenizer(inputs, return_tensors='pt', padding=True)
        # for key, value in encoded_inp.items():
        #     encoded_inp[key] = value.to('cuda:0')

        ## Invoke the model
        # with torch.no_grad():
        #     gen_tokens = self.model.generate(
        #         input_ids=encoded_inp['input_ids'],
        #         attention_mask=encoded_inp['attention_mask'],
        #         **generate_kwargs,
        #     )
        # ## Decode using tokenizer
        # decoded_gen = self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
        with torch.no_grad():
            output_ids = self.model.generate(input_ids, **generate_kwargs)

        # Slice the output_ids tensor to keep only the newly generated tokens
        new_tokens = output_ids[0, input_ids.shape[1]:]
        output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        return [{"gen_text": output_text, "input_text": inputs}]
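

# A minimal local smoke test, assuming the model weights are available on
# disk. The path below is a hypothetical placeholder, not the deployment
# path; Hugging Face Inference Endpoints instantiate EndpointHandler with
# the repository path automatically.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical local model dir
    result = handler({
        "inputs": "Translate to German: How old are you?",
        "max_new_tokens": 32,
    })
    print(result[0]["gen_text"])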