zidsi committed b64b8b9 · 1 parent: 3d1458b

try stream

Files changed (1): app.py (+8 −3)
app.py CHANGED

@@ -3,7 +3,7 @@ import spaces
 
 # from huggingface_hub import InferenceClient
 from transformers import pipeline
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
 import os
 HF_TOKEN = os.getenv('HF_TOKEN')
 
@@ -12,12 +12,17 @@ device = "cuda" # "cuda" or "cpu"
 tokenizer = AutoTokenizer.from_pretrained(checkpoint,token=HF_TOKEN)
 model = AutoModelForCausalLM.from_pretrained(checkpoint,token=HF_TOKEN)
 model.to(device)
+streamer = TextStreamer(tok)
 @spaces.GPU
 def predict(message, history):
     history.append({"role": "user", "content": message})
     input_text = tokenizer.apply_chat_template(history, tokenize=False)
-    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
-    outputs = model.generate(inputs, max_new_tokens=100, temperature=0.2, top_p=0.9, do_sample=True)
+    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+    # Use TextStreamer for streaming response
+    streamer = TextStreamer(tokenizer)
+    outputs = model.generate(inputs, max_new_tokens=512, temperature=0.2, top_p=0.9, do_sample=True, streamer=streamer)
+    # Despite returning the usual output, the streamer will also print the generated text to stdout.
+
     decoded = tokenizer.decode(outputs[0])
     response = decoded.split("[INST]")[-1].split("[/INST]")[-1]
     return response
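
As committed, TextStreamer only prints generated tokens to stdout (the Space logs); the Gradio chat UI still waits for the final response string. Note also that the module-level streamer = TextStreamer(tok) references tok, while the tokenizer variable in this file is tokenizer. For token-by-token streaming in the UI itself, a common pattern is TextIteratorStreamer plus a background generation thread, with predict yielding partial text. The sketch below is a hypothetical variant, assuming predict is wired into something like gr.ChatInterface (not shown in this diff); checkpoint, tokenizer, model, and device are the names from the code above.

from threading import Thread
from transformers import TextIteratorStreamer

@spaces.GPU
def predict(message, history):
    history.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(history, tokenize=False)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # TextIteratorStreamer yields decoded text chunks instead of printing to stdout;
    # skip_prompt=True also removes the need for the "[INST]" split post-processing.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs=inputs,
        max_new_tokens=512,
        temperature=0.2,
        top_p=0.9,
        do_sample=True,
        streamer=streamer,
    )

    # Run generation on a background thread so the main thread can consume the streamer.
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial  # gr.ChatInterface renders each yielded string as the in-progress reply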