KingNish committed
Commit b34ac00 · verified · 1 Parent(s): 6f2ede7

Update app.py

Files changed (1)
  1. app.py +278 -54
app.py CHANGED
@@ -3,65 +3,289 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStream
  import torch
  from threading import Thread
  import spaces

- # Load the model and tokenizer
  model_name = "sarvamai/sarvam-m"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

  @spaces.GPU
- def generate_response(prompt, chat_history):
-     messages = [{"role": "user", "content": prompt}]
-     text = tokenizer.apply_chat_template(messages, tokenize=False, enable_thinking=True)
-
-     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-
-     # Use TextIteratorStreamer for streaming
-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-     # Conduct text generation with streaming
-     generation_kwargs = dict(
-         input_ids=model_inputs.input_ids,
-         max_new_tokens=8192,
-         do_sample=True,
-         temperature=0.7,
-         streamer=streamer,
-     )

-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     # Initialize variables to capture reasoning content and main content
-     reasoning_content = ""
-     content = ""
-
-     for new_text in streamer:
-         if "</think>" in new_text:
-             parts = new_text.split("</think>")
-             reasoning_content = parts[0].rstrip("\n")
-             content = parts[-1].lstrip("\n").rstrip("</s>")
-             yield reasoning_content, content
-         else:
-             content += new_text
-             yield reasoning_content, content
-
- # Create the Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown("# Sarvam AI Chatbot")
-     chatbot = gr.Chatbot()
-     msg = gr.Textbox(label="Your Message")
-
-     def respond(message, chat_history):
-         chat_history.append((message, ""))
-         reasoning_content = ""
-         content = ""
-         for reasoning_part, content_part in generate_response(message, chat_history):
-             reasoning_content = reasoning_part
-             content = content_part
-             chat_history[-1] = (message, f"{reasoning_content}\n{content}" if reasoning_content else content)
-             yield chat_history, ""
-
-     msg.submit(respond, [msg, chatbot], [chatbot, msg])

  if __name__ == "__main__":
-     demo.launch()
 
  import torch
  from threading import Thread
  import spaces
+ import time

+ # For the advanced UI components
+ import modelscope_studio.components.antd as antd
+ import modelscope_studio.components.antdx as antdx
+ import modelscope_studio.components.base as ms
+ import modelscope_studio.components.pro as pro
+ from modelscope_studio.components.pro.chatbot import (ChatbotBotConfig,
+                                                       ChatbotPromptsConfig,
+                                                       ChatbotUserConfig,
+                                                       ChatbotWelcomeConfig)
+
+ # --- 1. Load the Hugging Face Model and Tokenizer ---
+ # This will be done once when the script starts.
  model_name = "sarvamai/sarvam-m"
+ print(f"Loading model: {model_name}...")
  tokenizer = AutoTokenizer.from_pretrained(model_name)
+ # Use float16 for less memory and faster inference on supported GPUs
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.float16,
+     device_map="auto"  # Automatically use the GPU if available
+ )
+ print("Model loaded successfully.")
+
+
+ # --- 2. Helper and Event Handler Functions ---
+
+ def format_history_for_sarvam(history: list) -> list:
+     """
+     Converts the pro.Chatbot's history format to the list of dictionaries
+     expected by the tokenizer's apply_chat_template method.
+     """
+     messages = []
+     if not history:
+         return messages
+
+     for item in history:
+         role = item.get("role")
+         content = item.get("content")
+
+         if role == "user":
+             messages.append({"role": "user", "content": content})
+         elif role == "assistant":
+             # Assistant content is a list of dicts (for tool/text). We need the final text.
+             final_content = ""
+             if isinstance(content, list):
+                 for part in content:
+                     if part.get("type") == "text":
+                         final_content = part.get("content", "")
+                         break
+             # In case of error or simple text response
+             elif isinstance(content, str):
+                 final_content = content
+
+             if final_content:
+                 messages.append({"role": "assistant", "content": final_content})
+
+     return messages

  @spaces.GPU
+ def submit(sender_value: str, chatbot_value: list):
+     """
+     The main function to handle user submission. It streams the model's response
+     and updates the UI in real-time.
+     """
+     # Append the new user message to the chat history
+     if sender_value:
+         chatbot_value.append({
+             "role": "user",
+             "content": sender_value,
+         })
+
+     # Append a placeholder for the assistant's response, with a loading indicator
+     chatbot_value.append({
+         "role": "assistant",
+         "content": [],
+         "loading": True,
+         "status": "pending"
+     })
+
+     # Initial UI update: clear the input box and show loading state
+     yield {
+         sender: gr.update(value=None, loading=True),
+         clear_btn: gr.update(disabled=True),
+         chatbot: gr.update(value=chatbot_value)
+     }
+
+     try:
+         # --- Model Inference ---
+         # 1. Format the conversation history for the model
+         history_messages = format_history_for_sarvam(chatbot_value)
+
+         # 2. Apply the chat template, enabling the <think> tag
+         prompt_text = tokenizer.apply_chat_template(
+             history_messages,
+             tokenize=False,
+             add_generation_prompt=True,
+             enable_thinking=True
+         )
+         model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
+
+         # 3. Set up the streamer
+         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+         # 4. Start generation in a separate thread
+         generation_kwargs = dict(
+             input_ids=model_inputs.input_ids,
+             max_new_tokens=8192,
+             do_sample=True,
+             temperature=0.7,
+             streamer=streamer,
+         )
+         thread = Thread(target=model.generate, kwargs=generation_kwargs)
+         thread.start()
+
+         # --- Stream and Parse the Response ---
+         start_time = time.time()
+
+         # Prepare the chatbot response structure for thinking + text
+         message_content = chatbot_value[-1]["content"]
+         message_content.append({
+             "copyable": False, "editable": False, "type": "tool", "content": "",
+             "options": {"title": "Thinking...", "status": "pending"}
+         })
+         message_content.append({"type": "text", "content": ""})
+
+         chatbot_value[-1]["loading"] = False
+
+         full_response = ""
+         thinking_content = ""
+         main_content = ""
+         thinking_done = False
+
+         for new_text in streamer:
+             full_response += new_text
+
+             if not thinking_done and "</think>" in full_response:
+                 thinking_done = True
+                 try:
+                     parts = full_response.split("</think>", 1)
+                     thinking_content = parts[0].split("<think>", 1)[1]
+                     main_content = parts[1]
+
+                     # Update the "Thinking" block in the UI
+                     thought_cost_time = "{:.2f}".format(time.time() - start_time)
+                     message_content[0]["content"] = thinking_content.strip()
+                     message_content[0]["options"]["title"] = f"End of Thought ({thought_cost_time}s)"
+                     message_content[0]["options"]["status"] = "done"
+
+                 except IndexError:
+                     # Malformed tags, treat everything as main content
+                     main_content = full_response
+
+             elif not thinking_done:
+                 # Still in the thinking block
+                 if full_response.lstrip().startswith("<think>"):
+                     thinking_content = full_response.lstrip()[len("<think>"):]
+                     message_content[0]["content"] = thinking_content.strip()
+             else:
+                 # Thinking is done, all new text is main content
+                 # This ensures the main_content variable is always up-to-date
+                 main_content = full_response.split("</think>", 1)[1]
+
+             # Update the main text response in the UI
+             message_content[1]["content"] = main_content.lstrip("\n")
+
+             yield {chatbot: gr.update(value=chatbot_value)}

+         # Finalize the response state
+         chatbot_value[-1]["footer"] = "{:.2f}s".format(time.time() - start_time)
+         chatbot_value[-1]["status"] = "done"
+         yield {
+             clear_btn: gr.update(disabled=False),
+             sender: gr.update(loading=False),
+             chatbot: gr.update(value=chatbot_value),
+         }
+
+     except Exception as e:
+         print(f"An error occurred: {e}")
+         chatbot_value[-1]["loading"] = False
+         chatbot_value[-1]["status"] = "done"
+         chatbot_value[-1]["content"] = f"Failed to respond due to an error: {e}"
+         yield {
+             clear_btn: gr.update(disabled=False),
+             sender: gr.update(loading=False),
+             chatbot: gr.update(value=chatbot_value),
+         }
+
+ def prompt_select(e: gr.EventData):
+     return gr.update(value=e._data["payload"][0]["value"]["description"])
+
+ def clear():
+     return gr.update(value=None)
+
+ def retry(chatbot_value: list, e: gr.EventData):
+     index = e._data["payload"][0]["index"]
+     # Remove the last assistant response and user message to retry
+     chatbot_value = chatbot_value[:index-1]
+
+     yield {
+         sender: gr.update(loading=True),
+         chatbot: gr.update(value=chatbot_value),
+         clear_btn: gr.update(disabled=True)
+     }
+     # Re-run submit with the truncated history
+     for chunk in submit(None, chatbot_value):
+         yield chunk
+
+ def cancel(chatbot_value: list):
+     # This function is called by Gradio to stop the generator
+     if chatbot_value and chatbot_value[-1].get("status") == "pending":
+         chatbot_value[-1]["loading"] = False
+         chatbot_value[-1]["status"] = "done"
+         chatbot_value[-1]["footer"] = "Chat completion paused"
+     return {
+         chatbot: gr.update(value=chatbot_value),
+         sender: gr.update(loading=False),
+         clear_btn: gr.update(disabled=False)
+     }
+
+
+ # --- 3. Build the Gradio UI ---
+
+ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo, ms.Application(), antdx.XProvider():
+     with antd.Flex(vertical=True, gap="middle"):
+         chatbot = pro.Chatbot(
+             height=650,
+             welcome_config=ChatbotWelcomeConfig(
+                 variant="borderless",
+                 icon="https://cdn-avatars.huggingface.co/v1/production/uploads/60270a7c32856987162c641a/umd13GCWVijwTDGZzw3q-.png",
+                 title=f"Hello, I'm {model_name.split('/')[-1]}",
+                 description="I can show you my thinking process. How can I help you today?",
+                 prompts=ChatbotPromptsConfig(
+                     items=[
+                         {"label": "Explain a concept", "children": [{"description": "Explain what a Large Language Model is in simple terms."}]},
+                         {"label": "Help me write", "children": [{"description": "Write a short, futuristic story about AI companions."}]},
+                         {"label": "Creative Ideas", "children": [{"description": "Give me three creative names for a new coffee shop."}]},
+                         {"label": "Code generation", "children": [{"description": "Write a python function to find the factorial of a number."}]}
+                     ]
+                 )
+             ),
+             user_config=ChatbotUserConfig(avatar="https://api.dicebear.com/7.x/miniavs/svg?seed=gradio"),
+             bot_config=ChatbotBotConfig(
+                 header=model_name,
+                 avatar="https://cdn-avatars.huggingface.co/v1/production/uploads/60270a7c32856987162c641a/umd13GCWVijwTDGZzw3q-.png",
+                 actions=["copy", "retry"]
+             ),
+         )
+
+         with antdx.Sender() as sender:
+             with ms.Slot("prefix"):
+                 with antd.Button(value=None, color="default", variant="text") as clear_btn:
+                     with ms.Slot("icon"):
+                         antd.Icon("ClearOutlined")
+
+     # --- Wire up the components and event handlers ---
+     clear_btn.click(fn=clear, outputs=[chatbot])
+
+     submit_event = sender.submit(
+         fn=submit,
+         inputs=[sender, chatbot],
+         outputs=[sender, chatbot, clear_btn]
+     )
+
+     sender.cancel(
+         fn=cancel,
+         inputs=[chatbot],
+         outputs=[chatbot, sender, clear_btn],
+         cancels=[submit_event],
+         queue=False  # Must be False for cancel to work instantly
+     )
+
+     chatbot.retry(
+         fn=retry,
+         inputs=[chatbot],
+         outputs=[sender, chatbot, clear_btn]
+     )
+
+     chatbot.welcome_prompt_select(
+         fn=prompt_select,
+         outputs=[sender]
+     )

  if __name__ == "__main__":
+     demo.queue().launch(debug=True)
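
The new format_history_for_sarvam() flattens the pro.Chatbot message structure, where assistant content is a list of typed blocks, into the plain role/content pairs that apply_chat_template() expects. A small illustration of that mapping, using a hypothetical history value shaped like the one the new UI keeps (the example strings are invented, only the structure follows the committed code):

# Hypothetical pro.Chatbot history: the assistant entry holds a "tool" block
# (the thinking trace) plus a "text" block (the visible answer).
history = [
    {"role": "user", "content": "Name three uses for a paperclip."},
    {"role": "assistant", "content": [
        {"type": "tool", "content": "The user wants creative, everyday uses."},
        {"type": "text", "content": "Hold papers, reset a router, act as a zipper pull."},
    ]},
    {"role": "user", "content": "Make the third one funnier."},
]

# What format_history_for_sarvam(history) would return: only the "text" block
# of each assistant turn survives, as a plain string; the thinking trace is dropped.
expected = [
    {"role": "user", "content": "Name three uses for a paperclip."},
    {"role": "assistant", "content": "Hold papers, reset a router, act as a zipper pull."},
    {"role": "user", "content": "Make the third one funnier."},
]

Only the first "text" part of an assistant turn is kept, so earlier reasoning is never fed back into the prompt.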
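
The new submit() handler accumulates the streamed tokens and splits them on the <think>…</think> markers so the reasoning trace and the visible answer render in separate blocks. A minimal, self-contained sketch of that parsing step, runnable without the model or a GPU; fake_stream is a hypothetical stand-in for the chunks TextIteratorStreamer would yield, while the splitting logic mirrors the committed code:

# Sketch of the <think>-tag parsing used in submit().
fake_stream = ["<think>", "The user asked for a greeting. ", "Keep it short.",
               "</think>", "\n\nHello! ", "How can I help you today?"]

full_response = ""
thinking_content = ""
main_content = ""
thinking_done = False

for new_text in fake_stream:
    full_response += new_text

    if not thinking_done and "</think>" in full_response:
        # The closing tag just arrived: split the reasoning from the answer once.
        thinking_done = True
        parts = full_response.split("</think>", 1)
        thinking_content = parts[0].split("<think>", 1)[1]
        main_content = parts[1]
    elif not thinking_done:
        # Still inside the thinking block; keep growing the reasoning trace.
        if full_response.lstrip().startswith("<think>"):
            thinking_content = full_response.lstrip()[len("<think>"):]
    else:
        # Thinking finished: everything after </think> is the visible answer.
        main_content = full_response.split("</think>", 1)[1]

print("thinking:", thinking_content.strip())
print("answer:", main_content.lstrip("\n"))

Running the sketch prints the reasoning trace and then the final answer, mirroring how the "tool" block and the "text" block of the assistant message are filled in the UI.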