File size: 8,447 Bytes
31b8d70
79186fd
cf229a5
 
f567464
816324f
 
 
 
 
 
79186fd
cda097b
 
f567464
816324f
f9dbf68
fef6b7e
 
79186fd
918a9fb
d5b26cf
 
79186fd
 
9c8d7cc
af50430
 
4293cb3
af50430
 
79186fd
d9961a5
918a9fb
d5b26cf
918a9fb
af50430
 
 
d9961a5
918a9fb
cf229a5
dd9a62d
 
 
 
 
cf229a5
 
 
af50430
cf229a5
 
 
 
 
 
850df56
 
 
 
af50430
cf229a5
918a9fb
cf229a5
d5b26cf
918a9fb
d5b26cf
dd9a62d
 
 
 
 
d5b26cf
 
 
6330f37
d5b26cf
 
f9dbf68
d5b26cf
 
 
850df56
 
 
 
6330f37
d5b26cf
918a9fb
 
a5749b0
d5b26cf
 
 
79186fd
31b8d70
c9aa76d
1ff178d
35bee3d
46123ec
c9aa76d
 
 
 
 
35bee3d
c9aa76d
 
 
 
d26d605
2803042
c9aa76d
 
 
 
 
 
31b8d70
 
 
c9aa76d
79186fd
c9aa76d
f567464
 
 
 
 
 
 
 
 
 
 
31b8d70
 
 
 
 
 
 
 
 
 
 
f567464
86cc435
b735702
51c26e3
a56aa3f
21ebf11
f567464
 
 
 
 
 
 
 
31b8d70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5749b0
 
86cc435
a5749b0
86cc435
 
37e7071
86cc435
 
 
 
f567464
86cc435
 
31b8d70
86cc435
918a9fb
f567464
86cc435
 
 
31b8d70
 
 
 
918a9fb
79186fd
31b8d70
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import gradio as gr 
from huggingface_hub import InferenceClient
import torch
from transformers import pipeline
import os
import sys

# Resolve the Hugging Face API token: the first command-line argument wins,
# otherwise fall back to the HF_TOKEN environment variable (None if unset).
if len(sys.argv) > 1:
    token = sys.argv[1]
else:
    token = os.getenv('HF_TOKEN')

# Report only whether a token was found. Never echo the value itself: it is a
# credential and printing it would leak it into console/service logs.
print("HF token: provided" if token else "HF token: not provided")

# Module-level flag shared by the event handlers: cancel_inference() sets it
# and respond() polls it to abort a running generation.
stop_inference = False

# Remote chat model, reached through huggingface_hub's InferenceClient using
# the token resolved above.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-alpha", token=token)

# Local chat model used when the "Use Local Model" checkbox is ticked.
# NOTE(review): this is loaded at import time even if the local option is
# never used in a session.
pipe = pipeline(
    task="text-generation",
    model="microsoft/Phi-3-mini-4k-instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

def _build_messages(message, history, system_message):
    """Build the chat message list for a single generation request.

    Starts with the system prompt, replays each ``(user, assistant)`` pair from
    ``history`` (skipping empty entries), and ends with ``message`` as the
    newest user turn.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})
    messages.append({"role": "user", "content": message})
    return messages


def respond(
    message,
    history: list[tuple[str, str]],
    system_message="You are a friendly Chatbot.",
    max_tokens=512,
    temperature=1.5,
    top_p=0.95,
    use_local_model=False,
):
    """Generate a reply to ``message`` and yield the updated chat history.

    Args:
        message: The new user message.
        history: Prior ``(user, assistant)`` pairs shown in the chat window;
            ``None`` is treated as an empty history.
        system_message: System prompt. An empty value (e.g. before any level
            button has been clicked) falls back to the default persona.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        use_local_model: If true, generate with the local ``pipe``; otherwise
            stream from the remote ``client``.

    Yields:
        ``history`` plus a ``(message, response_so_far)`` pair after each
        generated chunk, or ``history`` plus ``(message, "Inference cancelled.")``
        once the module-level ``stop_inference`` flag has been set.
    """
    global stop_inference
    stop_inference = False  # Clear any cancel request left over from a previous run

    if history is None:
        history = []
    # The level textbox starts out empty; don't send a blank system prompt.
    if not system_message:
        system_message = "You are a friendly Chatbot."

    messages = _build_messages(message, history, system_message)
    response = ""

    if use_local_model:
        # The pipeline call finishes generating before the loop starts, so a
        # cancel request is only noticed after generation has completed.
        outputs = pipe(
            messages,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True,
            top_p=top_p,
        )
        for output in outputs:
            if stop_inference:
                yield history + [(message, "Inference cancelled.")]
                return
            # The last entry of the returned conversation is the new assistant turn.
            response += output["generated_text"][-1]["content"]
            yield history + [(message, response)]
        return

    # stream=True is required: the loop below consumes incremental chunks
    # (choices[0].delta), which a non-streamed completion does not provide.
    stream = client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )
    for chunk in stream:
        if stop_inference:
            yield history + [(message, "Inference cancelled.")]
            return
        if not chunk.choices:
            continue
        # delta.content can be None on some chunks (e.g. the final one).
        response += chunk.choices[0].delta.content or ""
        yield history + [(message, response)]


def cancel_inference():
    """Request that the in-progress ``respond`` call stop.

    Sets the module-level ``stop_inference`` flag to True. ``respond`` checks
    the flag between generated chunks and clears it when a new request starts.
    """
    global stop_inference
    stop_inference = True

# Page styling: a centered, rounded, light-green container, light-blue buttons,
# and a grey, not-allowed look for disabled buttons (the level buttons are
# disabled after one is picked).
# Button text is dark blue. The previous value "light blue" is not a valid CSS
# color, so browsers dropped the declaration. A light-blue text color would also
# be unreadable on the #a7e0fd button background.
# NOTE(review): no component in this app sets elem_id="main-container", so the
# #main-container rule may never match.
custom_css = """
#main-container {
    background: #cdebc5;
    font-family: 'Comic Neue', sans-serif;
}
.gradio-container {
    max-width: 700px;
    margin: 0 auto;
    padding: 20px;
    background: #cdebc5;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    border-radius: 10px;
}
.gr-button {
    background-color: #a7e0fd;
    color: #1a3d5c;
    border: none;
    border-radius: 5px;
    padding: 10px 20px;
    cursor: pointer;
    transition: background-color 0.3s ease;
}
.gr-button:disabled {
    background-color: grey;
    cursor: not-allowed;
}
"""

# Opening shared by every level's system prompt.
_WORMINGTON_INTRO = (
    "Your name is Wormington. You are a friendly Chatbot that can help "
    "answer questions from "
)

# Audience and tone instructions appended to the intro, keyed by the label on
# each level-selection button (checked in this order).
_LEVEL_INSTRUCTIONS = {
    "Elementary School": "elementary school students. Please respond with the vocabulary that a seven-year-old can understand.",
    "Middle School": "middle school students. Please respond at a level that middle schoolers can understand.",
    "High School": "high school students. Please respond at a level that a high schooler can understand.",
    "College": "college students. Please respond using very advanced, college-level vocabulary.",
}


def update_system_message(level):
    """Return the Wormington system prompt for the given school ``level``.

    ``level`` is compared with ``==`` against "Elementary School",
    "Middle School", "High School" and "College". Any other value yields None.
    """
    return next(
        (
            _WORMINGTON_INTRO + instructions
            for label, instructions in _LEVEL_INSTRUCTIONS.items()
            if level == label
        ),
        None,
    )

# Lock the level selection once the user has picked one.
def disable_buttons_and_update_message(level):
    """Look up the prompt for ``level`` and disable all four level buttons.

    Returns a 5-tuple: the system prompt text, followed by one
    ``gr.update(interactive=False)`` for each of the four level buttons.
    """
    prompt = update_system_message(level)
    button_updates = tuple(gr.update(interactive=False) for _ in range(4))
    return (prompt,) + button_updates

# Restart: clear the chosen system message and re-enable the level buttons.
def restart_chatbot():
    """Return Gradio updates that undo a level selection.

    The five values are: an update for the system-message textbox (emptied),
    then one update per level button (elementary, middle, high, college),
    each re-enabled. Chat history and the input box are not touched.

    The textbox stays non-interactive. It is a read-only display, and making
    it interactive here would let users edit the system prompt after a restart.
    """
    return (
        gr.update(value="", interactive=False),
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True),
    )

# Define interface
def _level_click_handler(level):
    """Return a zero-argument click callback that selects ``level``."""
    # A factory (rather than a lambda inside the loop below) binds each level
    # now, avoiding Python's late-binding closure pitfall.
    return lambda: disable_buttons_and_update_message(level)


with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("<h2 style='text-align: center;'>🍎✏️ School AI Chatbot ✏️🍎</h2>")
    gr.Markdown("<h1 style= 'text-align: center;'>Interact with Wormington Scholar 🐛 by selecting the appropriate level below!</h1>")

    with gr.Row():
        elementary_button = gr.Button("Elementary School", elem_id="elementary", variant="primary")
        middle_button = gr.Button("Middle School", elem_id="middle", variant="primary")
        high_button = gr.Button("High School", elem_id="high", variant="primary")
        college_button = gr.Button("College", elem_id="college", variant="primary")

    # Read-only display of the selected system message
    system_message_display = gr.Textbox(label="System Message", value="", interactive=False)

    # Components updated by a level selection and by restart, in the order
    # returned by disable_buttons_and_update_message and restart_chatbot.
    level_outputs = [system_message_display, elementary_button, middle_button, high_button, college_button]

    # Each level button fills in its system prompt and disables all four buttons.
    for level_button, level_name in (
        (elementary_button, "Elementary School"),
        (middle_button, "Middle School"),
        (high_button, "High School"),
        (college_button, "College"),
    ):
        level_button.click(fn=_level_click_handler(level_name), inputs=None, outputs=level_outputs)

    with gr.Row():
        use_local_model = gr.Checkbox(label="Use Local Model", value=False)

    with gr.Row():
        max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
        temperature = gr.Slider(minimum=0.5, maximum=4.0, value=1.2, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    chat_history = gr.Chatbot(label="Chat")

    user_input = gr.Textbox(show_label=False, placeholder="Wormington would love to answer your questions. Type them here:")

    # "stop" is Gradio's button variant for stop/cancel actions ("danger" is not a valid variant).
    cancel_button = gr.Button("Cancel Inference", variant="stop")
    restart_button = gr.Button("Restart Chatbot", variant="secondary")

    # respond is a generator, so the chat window updates as each chunk is yielded.
    user_input.submit(
        respond,
        [user_input, chat_history, system_message_display, max_tokens, temperature, top_p, use_local_model],
        chat_history,
    )

    cancel_button.click(cancel_inference)

    # Clear the system message and re-enable the level buttons.
    restart_button.click(fn=restart_chatbot, inputs=None, outputs=level_outputs)

if __name__ == "__main__":
    # Serve locally only: share=False skips creating a public share link.
    demo.launch(share=False)