File size: 5,222 Bytes
d18942e
 
6836f82
d18942e
 
9ea9021
d18942e
9ea9021
d18942e
9ea9021
d18942e
cdadbc1
d18942e
 
5be5da4
cdadbc1
5be5da4
d18942e
 
 
9ea9021
d18942e
 
 
 
9ea9021
d18942e
cdadbc1
 
 
 
 
 
 
 
 
 
d18942e
 
5be5da4
d18942e
 
 
 
 
238ce74
d18942e
 
e262200
 
 
 
 
d18942e
 
 
 
 
 
6836f82
d18942e
 
fd40b8f
 
 
 
d18942e
747d063
e479ffb
 
66309f2
e479ffb
747d063
e479ffb
747d063
e479ffb
d18942e
 
 
 
 
984f08f
d18942e
 
 
984f08f
d18942e
 
 
6836f82
 
 
 
 
 
 
 
 
 
 
 
 
d18942e
cdadbc1
 
 
 
149b76e
cdadbc1
149b76e
6836f82
 
d18942e
5be5da4
d18942e
cdadbc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd40b8f
 
149b76e
 
fd40b8f
 
cdadbc1
d18942e
9ea9021
 
 
d18942e
cdadbc1
 
 
 
149b76e
cdadbc1
 
 
149b76e
d18942e
 
 
5be5da4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import gradio as gr

# Model repository ids known to this demo.
MODEL_LIST = ["nawhgnuj/KamalaHarris-Llama-3.1-8B-Chat"]
# Optional Hub access token; None when the HF_TOKEN variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Checkpoint to load; the MODEL_ID environment variable overrides the default.
MODEL = os.getenv("MODEL_ID", MODEL_LIST[0])

# Page heading, rendered through gr.HTML when the interface is built.
TITLE = (
    "<h1 style='color: #1565C0; text-align: center;'>"
    "Kamala Harris Chatbot"
    "</h1>"
)

# Avatar image URL handed to the Chatbot component.
KAMALA_AVATAR = (
    "https://upload.wikimedia.org/wikipedia/commons/thumb/4/41/"
    "Kamala_Harris_Vice_Presidential_Portrait.jpg/"
    "640px-Kamala_Harris_Vice_Presidential_Portrait.jpg"
)

# Stylesheet handed to gr.Blocks(css=...) when the interface is built.
# NOTE(review): the first rule targets the class `.chatbot`, while the Chatbot
# component in this file is created with elem_id="chatbot" (an id) — confirm
# the selector actually matches the rendered element.
CSS = """
.chatbot {
    background-color: white;
}
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: #1565C0 !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
    color: #1565C0;
}
.contain {object-fit: contain}
.avatar {width: 80px; height: 80px; border-radius: 80%; object-fit: cover;}
.user-message {
    background-color: white !important;
    color: black !important;
}
.bot-message {
    background-color: #1565C0 !important;
    color: white !important;
}
"""

# Device label for the current host. Nothing below reads it: weight placement
# is delegated to device_map="auto".
device = "cuda" if torch.cuda.is_available() else "cpu"

# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# HF_TOKEN is forwarded so gated or private checkpoints can be downloaded.
# When it is None the library falls back to its default credential lookup.
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)

# generate_response passes pad_token_id to generate(), so make sure one exists
# by reusing the end-of-sequence token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
)

def generate_response(
    message: str,
    history: list,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
):
    """Generate one persona reply to *message* given the prior chat turns.

    Args:
        message: The newest user message.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs.
            Pairs with a missing side are ignored.
        temperature: Sampling temperature forwarded to ``model.generate``.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff (coerced to ``int``).

    Returns:
        The decoded continuation with special tokens removed and surrounding
        whitespace stripped.
    """
    system_prompt = """You are a Kamala Harris chatbot. You only answer like Harris in style and tone. In every response:
1. Maintain a composed and professional demeanor.
2. Use clear, articulate language to explain complex ideas.
3. Emphasize your experience as a prosecutor and senator if needed.
4. Focus on policy details and their potential impact on Americans.
5. Stress the importance of unity and collaboration.

Crucially, Keep responses concise and impactful."""

    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history or []:
        # A turn whose reply never arrived (e.g. an earlier generation failed)
        # has no text to render; passing None into the template would break it.
        if prompt is None or answer is None:
            continue
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message})

    # return_dict=True also yields the attention mask. Passing it explicitly
    # matters because the pad token may be the same as the EOS token, in which
    # case generate() cannot infer the mask on its own.
    inputs = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            # UI sliders may deliver whole numbers as floats.
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            top_p=top_p,
            top_k=int(top_k),
            temperature=temperature,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Drop the prompt tokens so only the newly generated text is decoded.
    response = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return response.strip()

def add_text(history, text):
    """Append a pending user turn to the chat history and clear the textbox.

    Args:
        history: Existing chat turns; ``None`` is treated as an empty history.
        text: The message the user just submitted.

    Returns:
        A ``(new_history, "")`` tuple. ``new_history`` is a fresh list ending
        with ``[text, None]``; the empty string resets the input box.
    """
    # The new turn is a list rather than a tuple so that the reply slot can
    # later be filled in by index assignment.
    return list(history or []) + [[text, None]], ""

def bot(history, temperature, max_new_tokens, top_p, top_k):
    """Fill in the assistant reply for the most recent user turn.

    Args:
        history: Chat turns whose last entry holds the pending user message
            in position 0.
        temperature: Sampling temperature passed to ``generate_response``.
        max_new_tokens: Generation length limit passed to ``generate_response``.
        top_p: Nucleus-sampling threshold passed to ``generate_response``.
        top_k: Top-k cutoff passed to ``generate_response``.

    Returns:
        A history list in which the last turn is ``[user_message, reply]``.
        An empty or missing history yields an empty list without running the
        model.
    """
    if not history:
        # Nothing to answer; avoid indexing into an empty history.
        return []

    earlier_turns = list(history[:-1])
    user_message = history[-1][0]
    bot_response = generate_response(
        user_message, earlier_turns, temperature, max_new_tokens, top_p, top_k
    )
    # Build a new final turn instead of assigning into the old one, so the
    # update also works when the pending turn is an immutable tuple.
    return earlier_turns + [[user_message, bot_response]]

# Build the chat interface: heading, chat window, input row, sampling controls,
# example prompts, and the event wiring that connects them.
with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
    gr.HTML(TITLE)
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=(None, KAMALA_AVATAR),
        height=600,
        bubble_full_width=False,
        show_label=False,
    )
    msg = gr.Textbox(
        placeholder="Ask Kamala Harris a question",
        container=False,
        scale=7
    )
    with gr.Row():
        submit = gr.Button("Submit", scale=1, variant="primary")
        clear = gr.Button("Clear", scale=1)

    with gr.Accordion("Advanced Settings", open=False):
        temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.8, step=0.1, label="Temperature")
        max_new_tokens = gr.Slider(minimum=50, maximum=1024, value=1024, step=1, label="Max New Tokens")
        # top_p is a probability mass: generation rejects values above 1.0,
        # so the slider must not offer them.
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.1, label="Top-p")
        top_k = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Top-k")

    gr.Examples(
        examples=[
            ["What are your thoughts on healthcare reform?"],
            ["How do you plan to address climate change?"],
            ["What's your stance on education policy?"],
        ],
        inputs=msg,
    )

    # Both the Submit button and pressing Enter in the textbox run the same
    # two-step chain: record the user turn (and clear the box), then generate
    # the reply with the current sampling settings.
    generation_inputs = [chatbot, temperature, max_new_tokens, top_p, top_k]
    for register in (submit.click, msg.submit):
        register(add_text, [chatbot, msg], [chatbot, msg], queue=False).then(
            bot, generation_inputs, chatbot
        )

    # Clear resets the chat window to an empty history.
    clear.click(lambda: [], outputs=[chatbot], queue=False)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()