Konstantin committed
Commit ab52e38 · 1 Parent(s): e0f33e4

add files and requirements

Files changed (2)
  1. handler.py +126 -0
  2. requirements.txt +5 -0
handler.py ADDED
@@ -0,0 +1,126 @@
+ from typing import Dict, Any
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+ from fastapi.responses import StreamingResponse
+ import uuid
+ import time
+ import json
+ from threading import Thread
+
+ class EndpointHandler:
+     def __init__(self, path: str = "openai/gpt-oss-20b"):
+         # Load tokenizer and model
+         self.tokenizer = AutoTokenizer.from_pretrained(path)
+         self.model = AutoModelForCausalLM.from_pretrained(path)
+         self.model.eval()
+
+         # Determine the computation device
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+
+     def openai_id(self, prefix: str) -> str:
+         # OpenAI-style id, e.g. "chatcmpl-" followed by 24 hex chars
+         return f"{prefix}-{uuid.uuid4().hex[:24]}"
+
+     def format_non_stream(self, model: str, text: str, prompt_length: int, completion_length: int, total_tokens: int) -> Dict[str, Any]:
+         # Create an OpenAI-compatible chat.completion payload
+         return {
+             "id": self.openai_id("chatcmpl"),
+             "object": "chat.completion",
+             "created": int(time.time()),
+             "model": model,
+             "choices": [{
+                 "index": 0,
+                 "message": {"role": "assistant", "content": text},
+                 "finish_reason": "stop"
+             }],
+             "usage": {
+                 "prompt_tokens": prompt_length,
+                 "completion_tokens": completion_length,
+                 "total_tokens": total_tokens
+             }
+         }
+
+     def format_stream(self, model: str, token: str, usage, finish_reason=None) -> bytes:
+         # Create one OpenAI-compatible chat.completion.chunk SSE event
+         payload = {
+             "id": self.openai_id("chatcmpl"),
+             "object": "chat.completion.chunk",
+             "created": int(time.time()),
+             "model": model,
+             "choices": [{
+                 "index": 0,
+                 "delta": {
+                     "content": token,
+                     "function_call": None,
+                     "refusal": None,
+                     "role": None,
+                     "tool_calls": None
+                 },
+                 "finish_reason": finish_reason,
+                 "logprobs": None
+             }],
+             "usage": usage
+         }
+
+         return f"data: {json.dumps(payload)}\n\n".encode("utf-8")
+
+     def generate(self, messages, model: str):
+         # `messages` is an OpenAI-style list of {"role", "content"} dicts,
+         # so apply the model's chat template rather than tokenizing it raw
+         model_inputs = self.tokenizer.apply_chat_template(
+             messages,
+             add_generation_prompt=True,
+             return_tensors="pt",
+             return_dict=True
+         ).to(self.device)
+         full_output = self.model.generate(**model_inputs, max_new_tokens=2048)
+         # Keep only the completion by slicing the prompt off each sequence
+         generated_ids = [
+             output_ids[len(input_ids):]
+             for input_ids, output_ids in zip(model_inputs.input_ids, full_output)
+         ]
+         text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+         input_length = model_inputs.input_ids.shape[1]  # Prompt tokens
+         output_length = full_output.shape[1]  # Total tokens (prompt + completion)
+         completion_tokens = output_length - input_length
+
+         return self.format_non_stream(model, text, input_length, completion_tokens, output_length)
+
+     def stream(self, messages, model):
+         model_inputs = self.tokenizer.apply_chat_template(
+             messages,
+             add_generation_prompt=True,
+             return_tensors="pt",
+             return_dict=True
+         ).to(self.device)
+         input_len = model_inputs.input_ids.shape[1]
+         streamer = TextIteratorStreamer(
+             self.tokenizer,
+             skip_prompt=True,
+             skip_special_tokens=True
+         )
+
+         generation_kwargs = dict(
+             **model_inputs,
+             streamer=streamer,
+             max_new_tokens=2048
+         )
+
+         # Generate in a background thread; the streamer yields text as it arrives
+         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+         thread.start()
+
+         completion_tokens = 0
+         for token in streamer:
+             # Count tokens in each decoded chunk
+             token_ids = self.tokenizer.encode(token, add_special_tokens=False)
+             completion_tokens += len(token_ids)
+
+             yield self.format_stream(model, token, None)
+
+         # Final chunk with stop reason and token counts
+         yield self.format_stream(model, "", {
+             "prompt_tokens": input_len,
+             "completion_tokens": completion_tokens,
+             "total_tokens": input_len + completion_tokens
+         }, finish_reason="stop")
+
+         # Terminator expected by OpenAI-style SSE clients
+         yield b"data: [DONE]\n\n"
+
+     def __call__(self, data: Dict[str, Any]):
+         messages = data.get("messages")
+         model = data.get("model")
+         stream = data.get("stream", False)
+
+         if not stream:
+             return self.generate(messages, model)
+         else:
+             return StreamingResponse(
+                 self.stream(messages, model),
+                 media_type="text/event-stream"
+             )
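
For reference, a minimal sketch of how the handler can be exercised locally, assuming an OpenAI-style request payload (the model name and messages here are illustrative only):

handler = EndpointHandler("openai/gpt-oss-20b")

# Non-streaming call returns an OpenAI-style chat.completion dict
response = handler({
    "model": "openai/gpt-oss-20b",
    "messages": [{"role": "user", "content": "Say hello."}],
    "stream": False
})
print(response["choices"][0]["message"]["content"])

# Streaming call returns a StreamingResponse wrapping SSE byte chunks
streaming = handler({
    "model": "openai/gpt-oss-20b",
    "messages": [{"role": "user", "content": "Say hello."}],
    "stream": True
})
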
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ transformers
+ torch
+ kernels
+ fastapi
+ triton
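
Once deployed, a custom handler like this receives the parsed JSON request body as its `data` dict, so a client can POST an OpenAI-style payload directly; a hedged sketch, assuming a Hugging Face Inference Endpoint where the URL and token are placeholders:

import requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HEADERS = {
    "Authorization": "Bearer <hf_token>",  # placeholder
    "Content-Type": "application/json"
}

resp = requests.post(API_URL, headers=HEADERS, json={
    "model": "openai/gpt-oss-20b",
    "messages": [{"role": "user", "content": "Say hello."}],
    "stream": False
})
print(resp.json()["choices"][0]["message"]["content"])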