suriya7 committed (verified)
Commit 769181d · 1 Parent(s): cd5a409

Update app.py

Files changed (1): app.py (+121, -1)
app.py CHANGED
@@ -1,3 +1,123 @@
  import gradio as gr

- gr.load("models/suriya7/Gemma-2b-SFT").launch()
+ import pprint
+ import subprocess
+ from threading import Thread
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+ # Print basic CPU information for the Space at startup.
+ result = subprocess.run(["lscpu"], text=True, capture_output=True)
+ pprint.pprint(result.stdout)
+
+
+ # Gemma-2b-SFT is a decoder-only checkpoint, so it is loaded as a causal LM.
+ checkpoint = "suriya7/Gemma-2b-SFT"
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+ model = AutoModelForCausalLM.from_pretrained(checkpoint)
+
+
+ def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
+     alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+ ### Instruction:
+ {}
+
+ ### Input:
+ {}
+
+ ### Response:
+ {}"""
+
+     inputs = tokenizer(
+         [
+             alpaca_prompt.format(
+                 "You are an AI assistant. Please ensure that the answers conclude with an end-of-sequence (EOS) token.",  # instruction
+                 user_text,  # input goes here
+                 "",  # output - leave this blank for generation!
+             )
+         ],
+         return_tensors="pt",
+         return_dict=True,
+     )
+
+     # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
+     # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
+     streamer = TextIteratorStreamer(
+         tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+     )
+     generate_kwargs = dict(
+         inputs,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         temperature=float(temperature),
+         top_k=top_k,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     # Pull the generated text from the streamer, and update the model output.
+     model_output = ""
+     for new_text in streamer:
+         model_output += new_text
+         yield model_output
+     return model_output
+
+
+ def reset_textbox():
+     return gr.update(value="")
+
+
+ with gr.Blocks() as demo:
+
+     with gr.Row():
+         with gr.Column(scale=4):
+             user_text = gr.Textbox(
+                 label="User input",
+             )
+             model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
+             button_submit = gr.Button(value="Submit")
+
+         with gr.Column(scale=1):
+             max_new_tokens = gr.Slider(
+                 minimum=1,
+                 maximum=1000,
+                 value=250,
+                 step=1,
+                 interactive=True,
+                 label="Max New Tokens",
+             )
+             top_p = gr.Slider(
+                 minimum=0.05,
+                 maximum=1.0,
+                 value=0.95,
+                 step=0.05,
+                 interactive=True,
+                 label="Top-p (nucleus sampling)",
+             )
+             top_k = gr.Slider(
+                 minimum=1,
+                 maximum=50,
+                 value=50,
+                 step=1,
+                 interactive=True,
+                 label="Top-k",
+             )
+             temperature = gr.Slider(
+                 minimum=0.1,
+                 maximum=5.0,
+                 value=0.8,
+                 step=0.1,
+                 interactive=True,
+                 label="Temperature",
+             )
+
+     user_text.submit(
+         run_generation,
+         [user_text, top_p, temperature, top_k, max_new_tokens],
+         model_output,
+     )
+     button_submit.click(
+         run_generation,
+         [user_text, top_p, temperature, top_k, max_new_tokens],
+         model_output,
+     )
+
+ # queue() is already enabled on the Blocks app, so launch() no longer needs the deprecated enable_queue flag.
+ demo.queue(max_size=32).launch(server_name="0.0.0.0")
+ # For local use:
+ # demo.launch(server_name="0.0.0.0")