import os

import gradio as gr
import spaces
# TextStreamer is only used by the optional streaming paths shown below.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
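
# Access token for the checkpoint (e.g. a Space secret); the repo is presumably
# gated or private, since the token is passed to from_pretrained below.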
HF_TOKEN = os.getenv('HF_TOKEN')

checkpoint = "zidsi/SLlamica_PT4SFT_v2"
device = "cuda"  # "cuda" or "cpu"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(checkpoint, token=HF_TOKEN)
model.to(device)
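# Optional: on memory-constrained GPUs the weights could instead be loaded in
# half precision, e.g. from_pretrained(checkpoint, token=HF_TOKEN,
# torch_dtype=torch.bfloat16) together with `import torch`.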

@spaces.GPU
def predict(message, history, max_new_tokens, temperature, top_p):
    history.append({"role": "user", "content": message})
    # add_generation_prompt=True asks the chat template to append the marker
    # after which the model should start its reply.
    input_text = tokenizer.apply_chat_template(
        history, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    # A TextStreamer would additionally print tokens to stdout as they are generated:
    # streamer = TextStreamer(tokenizer)
    # outputs = model.generate(inputs, streamer=streamer, ...)
    outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # Decode only the newly generated tokens, dropping the prompt and special
    # tokens such as </s>.
    response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
    return response
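
# A streaming variant (a sketch, not wired into the demo below): gr.ChatInterface
# also accepts generator functions and renders each yielded string as the
# updated reply, so the streamer idea above can be realized with
# TextIteratorStreamer, which runs generate() in a background thread and yields
# decoded text as it arrives. The name `predict_stream` is illustrative; to try
# it, pass it to gr.ChatInterface in place of `predict`.
@spaces.GPU
def predict_stream(message, history, max_new_tokens, temperature, top_p):
    from threading import Thread

    from transformers import TextIteratorStreamer

    history.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(
        history, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    # skip_prompt=True makes the streamer yield only newly generated text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            inputs=inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        ),
    )
    thread.start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial
    thread.join()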

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    predict,
    type="messages",
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.05, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.01,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()