File size: 1,771 Bytes
37bde29
0753a4b
37bde29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dea10e
37bde29
 
 
 
 
5f6c807
 
37bde29
 
 
 
 
 
 
 
 
 
 
 
 
 
f096e93
 
37bde29
9522763
 
 
c33d1af
9522763
 
 
37bde29
ecaf5c0
5f6c807
 
 
 
ecaf5c0
5f6c807
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from typing import Any
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


def extract_assistant_response_simple(response_text):
    """Return the assistant reply from a decoded Llama-3-style chat transcript.

    The reply is the text that follows the first
    ``<|start_header_id|>assistant<|end_header_id|>`` header, cut at the next
    ``<|eot_id|>`` or the next assistant header (whichever comes first), with
    surrounding whitespace stripped.

    If the transcript contains no assistant header (e.g. an empty string or a
    model whose chat template uses different markers), the whole text is
    treated as the reply, still cut at the first ``<|eot_id|>``, instead of
    failing with an ``IndexError``.

    Args:
        response_text: Decoded model output with special tokens left in.

    Returns:
        The stripped assistant reply (possibly an empty string).
    """
    segments = response_text.split("<|start_header_id|>assistant<|end_header_id|>")
    # segments[1] is the text between the first assistant header and the next
    # one (or the end). Without any header, split() yields only segments[0],
    # the entire transcript, which is used as a fallback.
    reply = segments[1] if len(segments) > 1 else segments[0]
    # Drop the end-of-turn marker and anything after it.
    return reply.split("<|eot_id|>", 1)[0].strip()


class EndpointHandler:
    """Inference-endpoint handler that answers one user message with a causal LM.

    The tokenizer and model are loaded once in ``__init__``. Each call wraps the
    request text in a fixed system/user chat, generates up to 64 new tokens and
    returns the decoded reply, the number of generated tokens and the wall time.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model from *path* and move the model to CUDA.

        Args:
            path: Directory or model id passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Without a device_map, from_pretrained leaves the weights on the CPU.
        # Requests are run on CUDA (see __call__), so the model must be moved
        # there explicitly; otherwise generate() hits a device mismatch.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype="auto",
        ).to("cuda")

    def __call__(self, data: Any):
        """Generate a reply to the request text.

        Args:
            data: Request payload. ``data["inputs"]`` is popped from the dict
                and used as the user message; if that key is absent, ``data``
                itself is used.

        Returns:
            A dict with ``"response"`` (the decoded assistant reply),
            ``"response_token_length"`` (number of newly generated tokens,
            excluding the prompt) and ``"elapsed"`` (seconds spent in this
            call, measured with ``time.perf_counter``).
        """
        start = time.perf_counter()

        text = data.pop("inputs", data)

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": text},
        ]

        inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,  # Must add for generation
            return_tensors="pt",
        ).to(self.model.device)  # same device as the model weights

        print(f"inputs={inputs}")

        # Restrict scaled-dot-product attention to the flash kernel only.
        # NOTE(review): torch.backends.cuda.sdp_kernel is deprecated in newer
        # PyTorch releases in favour of torch.nn.attention.sdpa_kernel; confirm
        # the deployed torch version before migrating.
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            outputs = self.model.generate(
                input_ids=inputs,
                max_new_tokens=64,
                # NOTE(review): KV cache disabled, presumably on purpose (e.g. to
                # keep the flash-only kernel usable); it makes decoding slower.
                use_cache=False,
                # temperature/min_p only apply when sampling; without an explicit
                # do_sample=True, generate() may use greedy decoding and ignore them.
                do_sample=True,
                temperature=1.5,
                min_p=0.1,
            )

        # generate() returns prompt + continuation; count only the new tokens.
        response_length = outputs.shape[-1] - inputs.shape[-1]
        response = extract_assistant_response_simple(self.tokenizer.decode(outputs[0]))
        elapsed = time.perf_counter() - start

        return {"response": response, "response_token_length": response_length, "elapsed": elapsed}