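"""Wormington, a school-level Q&A chatbot built with Gradio.

Answers student questions at a selected grade level (elementary school
through college), using either a hosted Zephyr-7B chat endpoint or an
optional local Phi-3-mini pipeline.
"""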
import gradio as gr
from huggingface_hub import InferenceClient
import torch
from transformers import pipeline
import os
# Hosted inference: a Zephyr-7B chat endpoint, authenticated with the
# HF_TOKEN environment variable (on HF Spaces, set it as a repository secret)
token = os.getenv("HF_TOKEN")
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta", token=token)

# Optional local model: a Phi-3-mini chat pipeline. device_map="auto" requires
# the `accelerate` package and places the model on a GPU when one is available.
pipe = pipeline(
    "text-generation",
    "microsoft/Phi-3-mini-4k-instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Global flag used to cancel an in-flight generation; the streaming loops in
# respond() poll it between chunks, so cancellation lands at a chunk boundary.
stop_inference = False
def respond(
    message,
    history: list[tuple[str, str]],
    system_message="You are a friendly Chatbot.",
    max_tokens=512,
    temperature=1.5,
    top_p=0.95,
    use_local_model=False,
):
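    """Stream a reply to `message`, keeping the running chat history.

    Yields the updated history after every streamed chunk so the Gradio
    Chatbot component can render the response incrementally.
    """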
    global stop_inference
    stop_inference = False  # Reset the cancellation flag for this request

    # Initialize history if it's None
    if history is None:
        history = []
    # Build the chat-format message list once; both inference paths use it
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    response = ""

    if use_local_model:
        # Local inference. The pipeline returns the finished reply in a single
        # pass (no token-level streaming), so this loop runs once per sequence.
        for output in pipe(
            messages,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True,
            top_p=top_p,
        ):
            if stop_inference:
                response = "Inference cancelled."
                yield history + [(message, response)]
                return
            # With chat-style input, generated_text holds the full message
            # list; the last entry is the assistant's reply.
            response += output["generated_text"][-1]["content"]
            yield history + [(message, response)]
    else:
        # API-based inference, streamed chunk by chunk
        for message_chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            if stop_inference:
                response = "Inference cancelled."
                yield history + [(message, response)]
                return
            delta = message_chunk.choices[0].delta.content
            if delta:  # the final chunk may carry no content
                response += delta
            yield history + [(message, response)]
def cancel_inference():
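    """Signal respond() to stop streaming at the next chunk boundary."""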
    global stop_inference
    stop_inference = True
# Custom CSS for a fancy look
custom_css = """
#main-container {
    background: #cdebc5;
    font-family: 'Comic Neue', sans-serif;
}
.gradio-container {
    max-width: 700px;
    margin: 0 auto;
    padding: 20px;
    background: #cdebc5;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    border-radius: 10px;
}
.gr-button {
    background-color: #a7e0fd;
    color: lightblue;
    border: none;
    border-radius: 5px;
    padding: 10px 20px;
    cursor: pointer;
    transition: background-color 0.3s ease;
}
.gr-button:hover {
    background-color: #45a049;
}
.gr-slider input {
    color: #4CAF50;
}
.gr-chat {
    font-size: 16px;
}
#title {
    text-align: center;
    font-size: 2em;
    margin-bottom: 20px;
    color: #a7e0fd;
}
#school_ai_image {
    width: 150px;
    height: auto;
    display: block;
    margin: 0 auto;
}
"""
# Define system messages for each level
def update_system_message(level):
    if level == "Elementary School":
        return "Your name is Wormington. You are a friendly Chatbot that can help answer questions from elementary school students. Please respond with the vocabulary that a seven-year-old can understand."
    elif level == "Middle School":
        return "Your name is Wormington. You are a friendly Chatbot that can help answer questions from middle school students. Please respond at a level that middle schoolers can understand."
    elif level == "High School":
        return "Your name is Wormington. You are a friendly Chatbot that can help answer questions from high school students. Please respond at a level that a high schooler can understand."
    elif level == "College":
        return "Your name is Wormington. You are a friendly Chatbot that can help answer questions from college students. Please respond using very advanced, college-level vocabulary."
    # Fall back to the generic persona used as respond()'s default
    return "You are a friendly Chatbot."
# Define interface
with gr.Blocks(css=custom_css) as demo:
gr.Markdown("<h2 style='text-align: center;'>πβοΈ School AI Chatbot βοΈπ</h2>")
gr.Image("wormington_headshot.jpg", elem_id="school_ai_image", show_label=False, interactive=False)
gr.Markdown("<h1 style= 'text-align: center;'>Interact with Wormington Scholar π by selecting the appropriate level below.")
    with gr.Row():
        elementary_button = gr.Button("Elementary School", elem_id="elementary", variant="primary")
        middle_button = gr.Button("Middle School", elem_id="middle", variant="primary")
        high_button = gr.Button("High School", elem_id="high", variant="primary")
        college_button = gr.Button("College", elem_id="college", variant="primary")

    # Display area for the selected system message
    system_message_display = gr.Textbox(label="System Message", value="", interactive=False)
    # Update the system message when a level button is clicked
    elementary_button.click(fn=lambda: update_system_message("Elementary School"), inputs=None, outputs=system_message_display)
    middle_button.click(fn=lambda: update_system_message("Middle School"), inputs=None, outputs=system_message_display)
    high_button.click(fn=lambda: update_system_message("High School"), inputs=None, outputs=system_message_display)
    college_button.click(fn=lambda: update_system_message("College"), inputs=None, outputs=system_message_display)
    with gr.Row():
        use_local_model = gr.Checkbox(label="Use Local Model", value=False)

    with gr.Row():
        max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
        temperature = gr.Slider(minimum=0.5, maximum=4.0, value=1.2, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    chat_history = gr.Chatbot(label="Chat")
    user_input = gr.Textbox(show_label=False, placeholder="Wormington would love to answer your questions. Type them here:")
    cancel_button = gr.Button("Cancel Inference", variant="stop")  # "stop" is Gradio's red button style
    # Stream responses into the chat, passing the current history through so it is maintained across turns
    user_input.submit(respond, [user_input, chat_history, system_message_display, max_tokens, temperature, top_p, use_local_model], chat_history)
    cancel_button.click(cancel_inference)
if __name__ == "__main__":
    demo.launch(share=False)  # Public share links are unnecessary on HF Spaces
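    # Local usage sketch (assumes this file is saved as app.py and a valid
    # Hugging Face access token is available in the shell):
    #   HF_TOKEN=<your-token> python app.py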