import pprint
import subprocess
from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

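# Log host CPU info at startup; `lscpu` is Linux-only, so this assumes a Linux
# host such as a Hugging Face Space.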
result = subprocess.run(["lscpu"], text=True, capture_output=True)
pprint.pprint(result.stdout)


tokenizer = AutoTokenizer.from_pretrained("suriya7/Gemma-2b-SFT")
model = AutoModelForCausalLM.from_pretrained("suriya7/Gemma-2b-SFT")
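# The checkpoint loads on CPU in full (float32) precision by default. On a GPU
# host you could instead pass, e.g., torch_dtype=torch.float16 and
# device_map="auto" (the latter needs the `accelerate` package and an
# `import torch`); mentioned only as a sketch, not used here.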


def run_generation(user_text, temperature, top_k, max_new_tokens):
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

    inputs = tokenizer(
        [
            alpaca_prompt.format(
                "You are an AI assistant. Please ensure that the answers conclude with an end-of-sequence (EOS) token.",  # instruction
                user_text,  # input goes here
                "",  # output - leave this blank for generation!
            )
        ],
        return_tensors="pt",
    )
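    # `inputs` is a BatchEncoding (a dict subclass holding input_ids and an
    # attention_mask), which is why it can be expanded into dict() below.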

    # Start generation on a separate thread so that we don't block the UI; the text is pulled from
    # the streamer in the main thread. The timeout lets exceptions raised in the generation thread
    # surface in the UI instead of hanging the stream.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        repetition_penalty=1.5,
        temperature=temperature,
        top_k=top_k,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer, and update the model output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output

def reset_textbox():
    return gr.update(value="")
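# Note: reset_textbox is not wired to any event yet; a Clear button could use
# it, e.g. `button_clear.click(reset_textbox, [], user_text)` with a
# hypothetical `button_clear = gr.Button("Clear")` added to the layout below.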


with gr.Blocks() as demo:

    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                label="User input",
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")

        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=1,
                maximum=1000,
                value=250,
                step=1,
                interactive=True,
                label="Max New Tokens",
            )
        
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=5.0,
                value=0.8,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
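            # A top-p (nucleus sampling) slider could be added alongside these
            # in the same way and passed through to model.generate as `top_p`;
            # it is left out of this demo.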

    user_text.submit(
        run_generation,
        [user_text, temperature, top_k, max_new_tokens],
        model_output,
    )
    button_submit.click(
        run_generation,
        [user_text, temperature, top_k, max_new_tokens],
        model_output,
    )

demo.queue(max_size=32).launch(server_name="0.0.0.0")
# For local use (no request queue, bind to localhost only):
# demo.launch()