johnstrenio committed on
Commit
2dc61e4
·
verified ·
1 Parent(s): 0a9fea4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +238 -99
app.py CHANGED
@@ -1,24 +1,13 @@
1
- from huggingface_hub import InferenceClient
2
- import gradio as gr
3
 
4
- client = InferenceClient(
5
- "mistralai/Mistral-7B-Instruct-v0.3"
6
- )
7
 
 
8
 
9
- def format_prompt(message, history):
10
- prompt = "<s>"
11
- for user_prompt, bot_response in history:
12
- prompt += f"[INST] {user_prompt} [/INST]"
13
- prompt += f" {bot_response}</s> "
14
- prompt += f"[INST] {message} [/INST]"
15
- return prompt
16
 
17
- def generate(
18
- prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
19
- ):
20
- # Define the system prompt
21
- system_prompt = '''
22
  You are a digital assistant for John "LJ" Strenio's Data science portfolio page. Here are some key details about John to keep in mind with your response.
23
  [John's Resume]:
24
  John Strenio
@@ -52,92 +41,242 @@ John’s from Vermont but spent most of his adult life in Salt Lake City Utah fo
52
  John currently lives in Portland Oregon with his partner where he enjoys surfing the cold water’s of the oregon coast and playing with his two miniature dachshunds “maddie” and “nova”.
53
  Remember you are a professional assistant and you would like to only discuss John and be helpful in answering questions about his professional life or reasonable questions about his as a person. Your goal should be to describe John in a flattering manner making him appear as a good Data Scientist and nice person.
54
  '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- # Ensure temperature is within the valid range
57
- temperature = max(float(temperature), 1e-2)
58
- top_p = float(top_p)
59
-
60
- generate_kwargs = dict(
61
- temperature=temperature,
62
- max_new_tokens=max_new_tokens,
63
- top_p=top_p,
64
- repetition_penalty=repetition_penalty,
65
- do_sample=True,
66
- seed=42,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  )
68
 
69
- # Combine system prompt and formatted user prompt
70
- if not history: # If the conversation is just starting
71
- formatted_prompt = f"{system_prompt}\n\n{format_prompt(prompt, history)}"
72
- else: # For subsequent prompts
73
- formatted_prompt = format_prompt(prompt, history)
74
-
75
- # Stream the generated response
76
- stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
77
- output = ""
78
-
79
- for response in stream:
80
- output += response.token.text
81
- yield output
82
- return output
83
-
84
-
85
-
86
- additional_inputs=[
87
- gr.Slider(
88
- label="Temperature",
89
- value=0.9,
90
- minimum=0.0,
91
- maximum=1.0,
92
- step=0.05,
93
- interactive=True,
94
- info="Higher values produce more diverse outputs",
95
- ),
96
- gr.Slider(
97
- label="Max new tokens",
98
- value=256,
99
- minimum=0,
100
- maximum=1048,
101
- step=64,
102
- interactive=True,
103
- info="The maximum numbers of new tokens",
104
- ),
105
- gr.Slider(
106
- label="Top-p (nucleus sampling)",
107
- value=0.90,
108
- minimum=0.0,
109
- maximum=1,
110
- step=0.05,
111
- interactive=True,
112
- info="Higher values sample more low-probability tokens",
113
- ),
114
- gr.Slider(
115
- label="Repetition penalty",
116
- value=1.2,
117
- minimum=1.0,
118
- maximum=2.0,
119
- step=0.05,
120
- interactive=True,
121
- info="Penalize repeated tokens",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  )
123
- ]
124
-
125
- css = """
126
- #mkd {
127
- height: 500px;
128
- overflow: auto;
129
- border: 1px solid #ccc;
130
- }
131
- """
132
 
133
- with gr.Blocks(css=css) as demo:
134
- gr.HTML("<h1><center>Mistral 7B Instruct<h1><center>")
135
- gr.HTML("<h3><center>In this demo, you can chat with <a href='https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2'>Mistral-7B-Instruct</a> model. 💬<h3><center>")
136
- gr.HTML("<h3><center>Learn more about the model <a href='https://huggingface.co/docs/transformers/main/model_doc/mistral'>here</a>. 📚<h3><center>")
137
- gr.ChatInterface(
138
- generate,
139
- additional_inputs=additional_inputs,
140
- examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."]]
141
  )
142
 
143
- demo.queue().launch(debug=True)
 
1
import os
from typing import Iterator

import gradio as gr

from model import run

# Whether to create a public Gradio share link.
# os.environ values are always strings, so the previous
# `os.environ.get("HF_PUBLIC", False)` treated "0"/"false" as enabled;
# parse the value into a real boolean instead.
HF_PUBLIC = os.environ.get("HF_PUBLIC", "").strip().lower() in ("1", "true", "yes", "on")
 
 
 
 
 
 
9
 
10
+ DEFAULT_SYSTEM_PROMPT = '''
 
 
 
 
11
  You are a digital assistant for John "LJ" Strenio's Data science portfolio page. Here are some key details about John to keep in mind with your response.
12
  [John's Resume]:
13
  John Strenio
 
41
  John currently lives in Portland Oregon with his partner where he enjoys surfing the cold water’s of the oregon coast and playing with his two miniature dachshunds “maddie” and “nova”.
42
  Remember you are a professional assistant and you would like to only discuss John and be helpful in answering questions about his professional life or reasonable questions about his as a person. Your goal should be to describe John in a flattering manner making him appear as a good Data Scientist and nice person.
43
  '''
44
# Hard ceiling for the "Max new tokens" slider; generate() also rejects
# anything above this value.
MAX_MAX_NEW_TOKENS = 4096
# Initial value for the "Max new tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 256
# Limit used by check_input_token_length(); measured with a character/turn
# heuristic, not a real tokenizer count.
MAX_INPUT_TOKEN_LENGTH = 4000

# Markdown heading rendered at the top of the app.
DESCRIPTION = """
# John's Assistant
"""
51
+
52
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    """Empty the input textbox while handing its text off to the saved_input state."""
    cleared_box, stashed_message = '', message
    return cleared_box, stashed_message
54
+
55
+
56
def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Append the user's message to the chat history with an empty bot slot.

    Mutates *history* in place and returns it, so the Chatbot component
    shows the question immediately while generation is still pending.
    """
    pending_turn = (message, '')
    history.append(pending_turn)
    return history
60
+
61
+
62
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Remove the most recent turn and return (shortened history, its user message).

    An empty history yields ('') for the message instead of raising.
    """
    if history:
        last_user_message = history.pop()[0]
    else:
        last_user_message = ''
    # `or ''` guards against a None user message stored in the turn.
    return history, last_user_message or ''
69
+
70
+
71
def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    """Stream the model's reply, yielding the growing chat history.

    Each yielded value is the prior history plus the in-progress
    (message, partial_response) turn, suitable for a gr.Chatbot output.

    Raises:
        ValueError: if max_new_tokens exceeds MAX_MAX_NEW_TOKENS.
    """
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        # Previously a bare `raise ValueError` — include context so the
        # failure is diagnosable from logs.
        raise ValueError(
            f'max_new_tokens must be <= {MAX_MAX_NEW_TOKENS}, got {max_new_tokens}')

    # history_with_input already ends with the pending (message, '') turn
    # appended by display_input earlier in the event chain; drop it so the
    # user message is not duplicated in the prompt history.
    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
    try:
        first_response = next(generator)
        yield history + [(message, first_response)]
    except StopIteration:
        # Model produced no output at all; show an empty reply rather than crash.
        yield history + [(message, '')]
    for response in generator:
        yield history + [(message, response)]
92
 
93
+
94
def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    """Run an example prompt to completion and return ('', final chat history).

    Drains the generate() stream, keeping only the last yielded history.
    The previous version referenced the loop variable after the loop, which
    raised NameError whenever the generator yielded nothing; start from an
    empty history so that case degrades gracefully.
    """
    final_history: list[tuple[str, str]] = []
    for final_history in generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50):
        pass
    return '', final_history
99
+
100
+
101
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    """Guard against over-long accumulated input before generation starts.

    Uses a character-count heuristic rather than a real tokenizer. The
    previous version added len(chat_history) — the *number of turns* — which
    barely grew with conversation size; count the characters of every stored
    turn instead so the guard actually tracks accumulated input.

    system_prompt is accepted to match the Gradio event wiring but is
    deliberately excluded from the count: it is a large constant and
    including it would trip the limit immediately.

    Raises:
        gr.Error: when the estimated input length exceeds MAX_INPUT_TOKEN_LENGTH.
    """
    input_token_length = len(message) + sum(
        len(user_msg or '') + len(bot_msg or '') for user_msg, bot_msg in chat_history)
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
105
+
106
+
107
# Gradio UI: chat window, controls, and the event chains that drive generation.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    # gr.DuplicateButton(value='Duplicate Space for private use',
    #                    elem_id='duplicate-button')

    with gr.Group():
        chatbot = gr.Chatbot(label='Discussion')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Tell me about John.',
                scale=10,
            )
            submit_button = gr.Button('Submit',
                                      variant='primary',
                                      scale=1,
                                      min_width=0)
    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear_button = gr.Button('🗑️ Clear', variant='secondary')

    # Carries the user's message from the "clear textbox" step to the
    # later steps of each event chain.
    saved_input = gr.State()

    # Hidden advanced options; the components below still feed generate().
    with gr.Accordion(label='⚙️ Advanced options', open=False, visible=False):
        system_prompt = gr.Textbox(label='System prompt',
                                   value=DEFAULT_SYSTEM_PROMPT,
                                   lines=0,
                                   interactive=False)
        # NOTE: removed four dead scalar assignments (max_new_tokens=256,
        # temperature=0.1, top_p=0.9, top_k=10) that were immediately
        # shadowed by the Slider components below.
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.1,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=10,
        )

    # Enter in the textbox: stash+clear the input, echo it into the chat,
    # validate length, then stream the reply (only if validation succeeds).
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Submit button mirrors the textbox.submit chain exactly.
    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Retry: drop the last turn, re-echo the message, and regenerate.
    # (Unlike submit, this chain performs no length check.)
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Undo: drop the last turn and restore its message into the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: wipe both the visible chat and the saved message.
    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

demo.queue(max_size=32).launch(share=HF_PUBLIC, show_api=False)