File size: 8,447 Bytes
31b8d70 79186fd cf229a5 f567464 816324f 79186fd cda097b f567464 816324f f9dbf68 fef6b7e 79186fd 918a9fb d5b26cf 79186fd 9c8d7cc af50430 4293cb3 af50430 79186fd d9961a5 918a9fb d5b26cf 918a9fb af50430 d9961a5 918a9fb cf229a5 dd9a62d cf229a5 af50430 cf229a5 850df56 af50430 cf229a5 918a9fb cf229a5 d5b26cf 918a9fb d5b26cf dd9a62d d5b26cf 6330f37 d5b26cf f9dbf68 d5b26cf 850df56 6330f37 d5b26cf 918a9fb a5749b0 d5b26cf 79186fd 31b8d70 c9aa76d 1ff178d 35bee3d 46123ec c9aa76d 35bee3d c9aa76d d26d605 2803042 c9aa76d 31b8d70 c9aa76d 79186fd c9aa76d f567464 31b8d70 f567464 86cc435 b735702 51c26e3 a56aa3f 21ebf11 f567464 31b8d70 a5749b0 86cc435 a5749b0 86cc435 37e7071 86cc435 f567464 86cc435 31b8d70 86cc435 918a9fb f567464 86cc435 31b8d70 918a9fb 79186fd 31b8d70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
import gradio as gr
from huggingface_hub import InferenceClient
import torch
from transformers import pipeline
import os
import sys
# Resolve the Hugging Face API token: a CLI argument takes priority,
# otherwise fall back to the HF_TOKEN environment variable.
if len(sys.argv) > 1:
    token = sys.argv[1]
else:
    token = os.getenv('HF_TOKEN')

# SECURITY FIX: never print the raw token (it is a credential and would land
# in logs). Report only whether one was found.
print("HF token provided:", token is not None)

# Remote inference client (Zephyr-7B served by the HF Inference API).
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-alpha", token=token)

# Local text-generation pipeline (Phi-3 mini), bfloat16 weights with
# automatic device placement.
pipe = pipeline("text-generation", "microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.bfloat16, device_map="auto")

# Global flag polled by the streaming loops in respond() to abort generation.
stop_inference = False
def respond(
    message,
    history: list[tuple[str, str]],
    system_message="You are a friendly Chatbot.",
    max_tokens=512,
    temperature=1.5,
    top_p=0.95,
    use_local_model=False,
):
    """Stream a chat reply for *message*, yielding the growing chat history.

    Generator used as a Gradio event handler: each yield is the full history
    (list of (user, assistant) tuples) with the partial response appended, so
    the Chatbot component updates incrementally.

    Args:
        message: The new user message.
        history: Prior (user, assistant) turns; None is treated as empty.
        system_message: System prompt prepended to the transcript.
        max_tokens: Cap on newly generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.
        use_local_model: True -> local transformers pipeline; False -> HF API.
    """
    global stop_inference
    stop_inference = False  # Reset cancellation flag for this request

    if history is None:
        history = []

    # Build the chat transcript once — both back ends consume the same format.
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    response = ""
    if use_local_model:
        # Local inference through the transformers pipeline.
        for output in pipe(
            messages,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True,
            top_p=top_p,
        ):
            if stop_inference:
                yield history + [(message, "Inference cancelled.")]
                return
            response += output['generated_text'][-1]['content']
            yield history + [(message, response)]
    else:
        # API-based inference.
        # BUG FIX: the original passed stream=False while iterating streaming
        # chunks via .choices[0].delta — delta chunks only exist when
        # stream=True.
        for message_chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            if stop_inference:
                yield history + [(message, "Inference cancelled.")]
                return
            # BUG FIX: delta.content can be None (e.g. on the final chunk);
            # guard to avoid `str += None`. Also renamed from `token`, which
            # shadowed the module-level auth token.
            chunk_text = message_chunk.choices[0].delta.content
            if chunk_text:
                response += chunk_text
            yield history + [(message, response)]
def cancel_inference():
    """Signal any in-flight generation loop in respond() to abort.

    Flips the module-level stop flag; the streaming loops check it between
    yielded chunks and bail out with an "Inference cancelled." message.
    """
    global stop_inference
    stop_inference = True
# Custom CSS: page theming plus a greyed-out look for disabled buttons.
custom_css = """
#main-container {
    background: #cdebc5;
    font-family: 'Comic Neue', sans-serif;
}
.gradio-container {
    max-width: 700px;
    margin: 0 auto;
    padding: 20px;
    background: #cdebc5;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    border-radius: 10px;
}
.gr-button {
    background-color: #a7e0fd;
    /* BUG FIX: "light blue" (with a space) is not a valid CSS color keyword
       and was silently ignored; the keyword is "lightblue". */
    color: lightblue;
    border: none;
    border-radius: 5px;
    padding: 10px 20px;
    cursor: pointer;
    transition: background-color 0.3s ease;
}
.gr-button:disabled {
    background-color: grey;
    cursor: not-allowed;
}
"""
# Persona prompts, one per supported school level.
def update_system_message(level):
    """Return the Wormington system prompt for *level*.

    Unknown levels yield None (same as the original if/elif chain falling
    through with no match).
    """
    prompts = {
        "Elementary School": "Your name is Wormington. You are a friendly Chatbot that can help answer questions from elementary school students. Please respond with the vocabulary that a seven-year-old can understand.",
        "Middle School": "Your name is Wormington. You are a friendly Chatbot that can help answer questions from middle school students. Please respond at a level that middle schoolers can understand.",
        "High School": "Your name is Wormington. You are a friendly Chatbot that can help answer questions from high school students. Please respond at a level that a high schooler can understand.",
        "College": "Your name is Wormington. You are a friendly Chatbot that can help answer questions from college students. Please respond using very advanced, college-level vocabulary.",
    }
    return prompts.get(level)
# Lock in a school level: show its system prompt and grey out every level button.
def disable_buttons_and_update_message(level):
    """Return (system_message, *button_updates) for a level-button click.

    The four gr.update(interactive=False) entries disable the Elementary,
    Middle, High, and College buttons respectively.
    """
    prompt = update_system_message(level)
    locked_buttons = [gr.update(interactive=False) for _ in range(4)]
    return (prompt, *locked_buttons)
# Reset handler: blank the system-message box and re-enable every level button.
def restart_chatbot():
    """Return updates restoring the UI to its pre-selection state.

    First element clears and re-enables the system-message textbox; the
    remaining four re-enable the level buttons.
    """
    cleared_box = gr.update(value="", interactive=True)
    enabled_buttons = [gr.update(interactive=True) for _ in range(4)]
    return (cleared_box, *enabled_buttons)
# Define interface
with gr.Blocks(css=custom_css) as demo:
    # NOTE(review): the header strings contain mojibake (e.g. "πβοΈ") — these
    # look like emoji lost in a re-encode; confirm against the intended UTF-8
    # source before "fixing" them.
    gr.Markdown("<h2 style='text-align: center;'>πβοΈ School AI Chatbot βοΈπ</h2>")
    gr.Markdown("<h1 style= 'text-align: center;'>Interact with Wormington Scholar π by selecting the appropriate level below!</h1>")
    # One button per school level; clicking any of them selects the persona
    # and disables all four buttons (see disable_buttons_and_update_message).
    with gr.Row():
        elementary_button = gr.Button("Elementary School", elem_id="elementary", variant="primary")
        middle_button = gr.Button("Middle School", elem_id="middle", variant="primary")
        high_button = gr.Button("High School", elem_id="high", variant="primary")
        college_button = gr.Button("College", elem_id="college", variant="primary")
    # Display area for the selected system message
    system_message_display = gr.Textbox(label="System Message", value="", interactive=False)
    # Disable buttons and update the system message when a button is clicked
    elementary_button.click(fn=lambda: disable_buttons_and_update_message("Elementary School"),
                            inputs=None,
                            outputs=[system_message_display, elementary_button, middle_button, high_button, college_button])
    middle_button.click(fn=lambda: disable_buttons_and_update_message("Middle School"),
                        inputs=None,
                        outputs=[system_message_display, elementary_button, middle_button, high_button, college_button])
    high_button.click(fn=lambda: disable_buttons_and_update_message("High School"),
                      inputs=None,
                      outputs=[system_message_display, elementary_button, middle_button, high_button, college_button])
    college_button.click(fn=lambda: disable_buttons_and_update_message("College"),
                         inputs=None,
                         outputs=[system_message_display, elementary_button, middle_button, high_button, college_button])
    # Toggle between the local transformers pipeline and the hosted API.
    with gr.Row():
        use_local_model = gr.Checkbox(label="Use Local Model", value=False)
    # Generation hyper-parameters passed straight through to respond().
    with gr.Row():
        max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
        temperature = gr.Slider(minimum=0.5, maximum=4.0, value=1.2, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
    chat_history = gr.Chatbot(label="Chat")
    user_input = gr.Textbox(show_label=False, placeholder="Wormington would love to answer your questions. Type them here:")
    # NOTE(review): documented gr.Button variants are "primary"/"secondary"/
    # "stop"; "danger" may be ignored or rejected depending on the Gradio
    # version — confirm.
    cancel_button = gr.Button("Cancel Inference", variant="danger")
    restart_button = gr.Button("Restart Chatbot", variant="secondary")
    # Adjusted to ensure history is maintained and passed correctly
    user_input.submit(respond, [user_input, chat_history, system_message_display, max_tokens, temperature, top_p, use_local_model], chat_history)
    # Flips the global stop flag; respond() checks it between streamed chunks.
    cancel_button.click(cancel_inference)
    # Reset the buttons when the "Restart Chatbot" button is clicked
    restart_button.click(fn=restart_chatbot,
                         inputs=None,
                         outputs=[system_message_display, elementary_button, middle_button, high_button, college_button])

# Launch only when executed as a script; share=False keeps the app local.
if __name__ == "__main__":
    demo.launch(share=False)
|