from typing import Any, Dict, List

import torch
import transformers
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList

# The model id was left blank in the original; supply the repo id or local
# path of the checkpoint whose tokenizer should be loaded.
tokenizer = AutoTokenizer.from_pretrained(
    "",
    trust_remote_code=True
)
# Fall back to the EOS token for padding if the tokenizer defines no pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
# Left-padding keeps prompts flush against the generated continuation.
tokenizer.padding_side = 'left'

stop_token_ids = tokenizer.convert_tokens_to_ids(["<|endoftext|>"])


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last emitted token is a stop token."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


class EndpointHandler:
    def __init__(self, path=""):
        self.torch_dtype = torch.bfloat16
        self.tokenizer = tokenizer

        self.config = transformers.AutoConfig.from_pretrained(
            path,
            trust_remote_code=True
        )
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            path,
            config=self.config,
            torch_dtype=self.torch_dtype,
            trust_remote_code=True
        )

        # Run on GPU when available, otherwise fall back to CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval()
        self.model.to(device=device, dtype=self.torch_dtype)

        # Default generation settings; requests may override most of them.
        self.generate_kwargs = {
            'max_new_tokens': 512,
            'temperature': 0.0001,  # near-zero temperature makes sampling effectively greedy
            'top_p': 1.0,
            'top_k': 0,
            'use_cache': True,
            'do_sample': True,
            'eos_token_id': self.tokenizer.eos_token_id,
            'pad_token_id': self.tokenizer.pad_token_id,
            'repetition_penalty': 1.1
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`): the prompt to generate from.
            Optional keys max_new_tokens, temperature, top_p, top_k,
            do_sample, and repetition_penalty override the defaults.
        Return:
            A :obj:`list` of :obj:`dict` that will be serialized and returned.
        """
        stop = StopOnTokens()

        # Override default generation settings with any values supplied in the request.
        for key in ('max_new_tokens', 'temperature', 'top_p', 'top_k',
                    'do_sample', 'repetition_penalty'):
            if key in data:
                self.generate_kwargs[key] = data[key]

        self.generate_kwargs['stopping_criteria'] = StoppingCriteriaList([stop])

        inputs = data.pop("inputs", data)
        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
        input_ids = input_ids.to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(input_ids, **self.generate_kwargs)

        # Decode only the newly generated tokens, skipping the echoed prompt.
        new_tokens = output_ids[0, len(input_ids[0]):]
        output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        return [{"gen_text": output_text, "input_text": inputs}]
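

# A minimal local smoke test of the handler (a sketch, not part of the
# endpoint contract). The checkpoint path and prompt below are illustrative
# placeholders, not values from the original handler.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    result = handler({"inputs": "Write a haiku about autumn.", "max_new_tokens": 64})
    print(result[0]["gen_text"])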