try stream
app.py
CHANGED
@@ -3,7 +3,7 @@ import spaces
 
 # from huggingface_hub import InferenceClient
 from transformers import pipeline
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
 import os
 HF_TOKEN = os.getenv('HF_TOKEN')
 
@@ -12,12 +12,17 @@ device = "cuda" # "cuda" or "cpu"
 tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=HF_TOKEN)
 model = AutoModelForCausalLM.from_pretrained(checkpoint, token=HF_TOKEN)
 model.to(device)
+streamer = TextStreamer(tokenizer)
 @spaces.GPU
 def predict(message, history):
     history.append({"role": "user", "content": message})
     input_text = tokenizer.apply_chat_template(history, tokenize=False)
-    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
-    outputs = model.generate(inputs)
+    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+    # Use TextStreamer for streaming response
+    streamer = TextStreamer(tokenizer)
+    outputs = model.generate(inputs, max_new_tokens=512, temperature=0.2, top_p=0.9, do_sample=True, streamer=streamer)
+    # Despite returning the usual output, the streamer will also print the generated text to stdout.
+
     decoded = tokenizer.decode(outputs[0])
     response = decoded.split("[INST]")[-1].split("[/INST]")[-1]
     return response
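As the added comment notes, TextStreamer only writes tokens to stdout; the chat window in the Space still receives the reply in one piece. To stream partial text into the Gradio UI itself, the usual transformers pattern is TextIteratorStreamer plus a background thread, with predict turned into a generator. A minimal sketch under those assumptions: it reuses the tokenizer, model, and device globals from app.py above, assumes the Space serves predict through gr.ChatInterface, and adds add_generation_prompt=True, which is not in the original.

    # Sketch only: stream tokens into the Gradio chat UI rather than stdout.
    # Assumes the checkpoint/tokenizer/model/device globals from app.py above.
    from threading import Thread

    import spaces
    from transformers import TextIteratorStreamer

    @spaces.GPU
    def predict(message, history):
        history.append({"role": "user", "content": message})
        # add_generation_prompt=True ends the template with the assistant turn
        input_text = tokenizer.apply_chat_template(
            history, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

        # TextIteratorStreamer buffers decoded text and exposes it as an iterator;
        # skip_prompt=True drops the echoed prompt, so no "[INST]" splitting is needed.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            inputs=inputs,
            max_new_tokens=512,
            temperature=0.2,
            top_p=0.9,
            do_sample=True,
            streamer=streamer,
        )
        # generate() blocks until completion, so run it in a background
        # thread and consume the streamer on this one.
        Thread(target=model.generate, kwargs=generation_kwargs).start()

        partial = ""
        for new_text in streamer:
            partial += new_text
            yield partial  # gr.ChatInterface re-renders the reply on each yield

Each yield replaces the in-progress assistant message, so the reply appears token by token instead of all at once.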
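Separately, the response = decoded.split("[INST]")[-1].split("[/INST]")[-1] line ties the app to Mistral-style template markers. A template-agnostic alternative, sketched with the same inputs/outputs variables as the diff, is to decode only the token ids generated after the prompt:

    # Sketch: drop the prompt by slicing token ids instead of splitting
    # on chat-template markers like "[INST]"/"[/INST]".
    prompt_len = inputs.shape[-1]        # prompt length in tokens
    response = tokenizer.decode(
        outputs[0][prompt_len:],         # newly generated ids only
        skip_special_tokens=True,
    )

This also survives a change of checkpoint, since no template-specific strings are involved.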