File size: 4,028 Bytes
8d81ea9
f92b0d5
ce493a4
f92b0d5
5214a6c
ce493a4
5f4efa7
590a39e
 
5f4efa7
f92b0d5
590a39e
6774afa
f7a5317
f92b0d5
590a39e
 
ce493a4
 
6774afa
ce493a4
6774afa
 
590a39e
ce493a4
590a39e
 
ce493a4
17da298
b4f77ad
 
590a39e
ce493a4
590a39e
ce493a4
 
590a39e
3a78f32
f6c62f1
ce493a4
 
f6c62f1
ce493a4
 
70e88c8
f6c62f1
70e88c8
f6c62f1
f7a5317
f6c62f1
ce493a4
590a39e
f92b0d5
ce493a4
590a39e
 
 
 
ce493a4
590a39e
 
 
 
 
 
ce493a4
 
 
 
 
 
590a39e
 
ce493a4
590a39e
5214a6c
ce493a4
5214a6c
 
ce493a4
5214a6c
590a39e
f92b0d5
5214a6c
af25cff
4e13938
 
f92b0d5
4e13938
b4f77ad
0a2169b
70e88c8
f7a5317
70e88c8
0a2169b
5214a6c
 
cac7ff6
f7a5317
590a39e
 
 
 
 
 
 
 
 
 
08ea239
f92b0d5
 
d02b539
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# ---------------- CONFIG ----------------
# Hugging Face Hub id of the instruction-tuned ("-it") Gemma 3 270M checkpoint.
MODEL_NAME = "google/gemma-3-270m-it"

# Initial value of the editable "System Prompt" box in the settings panel.
SYSTEM_PROMPT_DEFAULT = (
    "You are a formal and polite AI assistant. Always respond appropriately "
    "depending on the selected explanation style."
)

# Starting positions of the generation sliders: max new tokens, sampling
# temperature, and nucleus (top-p) probability mass.
MAX_NEW_TOKENS_DEFAULT, TEMP_DEFAULT, TOP_P_DEFAULT = 256, 0.7, 0.9

# ---------------- LOAD MODEL ----------------
# Runs once, at import time. Everything stays on the CPU: the weights are
# loaded as full float32 and the pipeline is pinned to device -1.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32)

# Gemma is a decoder-only (causal) LM, so the task is "text-generation",
# not a seq2seq task.
generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,
)

# ---------------- HELPERS ----------------
def format_prompt(chat_history, user_message, system_message, response_style):
    """Assemble a plain-text prompt from the conversation so far.

    The prompt is laid out as follows:

    1. The system message, then a blank line.
    2. Every earlier *user* turn on its own line. Assistant turns are
       intentionally left out.
    3. The new user message on its own line.
    4. An optional instruction chosen by ``response_style``.

    Unrecognised styles add no instruction.

    Args:
        chat_history: Earlier turns as dicts with "role" and "content" keys.
        user_message: The message being answered now.
        system_message: Text placed at the very top; must be a ``str``
            (it is concatenated with ``+``).
        response_style: "No explanation", "Short explanation" or
            "Detailed explanation".

    Returns:
        The assembled prompt string.
    """
    style_instructions = {
        "No explanation": " Answer concisely with no explanation.",
        "Short explanation": " Answer briefly with a one-sentence explanation.",
        "Detailed explanation": " Answer in detail with reasoning and examples.",
    }

    header = system_message + "\n\n"
    earlier_user_lines = "".join(
        f"{turn['content']}\n" for turn in chat_history if turn["role"] == "user"
    )
    current_line = f"{user_message}\n"
    # Each instruction carries its own leading space and follows the newline above.
    instruction = style_instructions.get(response_style, "")

    return "".join([header, earlier_user_lines, current_line, instruction])


# ---------------- CHAT FUNCTION ----------------
# Instruction appended to the user's turn for each "Response Style" choice.
_STYLE_INSTRUCTIONS = {
    "No explanation": "Answer concisely with no explanation.",
    "Short explanation": "Answer briefly with a one-sentence explanation.",
    "Detailed explanation": "Answer in detail with reasoning and examples.",
}


def _build_messages(history, user_message, system_message, response_style):
    """Return the role/content message list for one generation call."""
    messages = []
    if system_message and system_message.strip():
        messages.append({"role": "system", "content": system_message})
    # Feed back both sides of the conversation so the model sees its own
    # earlier replies. Copy only role/content, because Gradio message dicts
    # may carry extra keys.
    # NOTE(review): chat templates such as Gemma's expect user/assistant turns
    # to alternate. chat() always appends them in pairs, so this holds for
    # histories produced by this app.
    for turn in history:
        role = turn.get("role")
        if role in ("user", "assistant"):
            messages.append({"role": role, "content": turn.get("content")})
    instruction = _STYLE_INSTRUCTIONS.get(response_style)
    content = f"{user_message}\n\n{instruction}" if instruction else user_message
    messages.append({"role": "user", "content": content})
    return messages


def chat(user_message, chat_history, system_message, max_tokens, temperature, top_p, response_style):
    """Generate one assistant reply and append the exchange to the history.

    Args:
        user_message: Text the user just submitted.
        chat_history: Gradio ``type="messages"`` history (a list of
            ``{"role": ..., "content": ...}`` dicts) or ``None``.
        system_message: System prompt; omitted from the request when blank.
        max_tokens: Upper bound on newly generated tokens. It is cast to
            ``int`` because slider values may arrive as floats.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        response_style: A "Response Style" dropdown choice that selects the
            instruction appended to the user's message.

    Returns:
        A tuple ``("", history)``. The empty string clears the input textbox,
        and ``history`` is a new list with the user turn and the reply
        appended. A blank message returns the history unchanged without
        calling the model.
    """
    # Copy so the caller's list is never mutated in place.
    history = list(chat_history or [])

    # Nothing to answer; skip the (slow, CPU-only) generation.
    if not user_message or not user_message.strip():
        return "", history

    messages = _build_messages(history, user_message, system_message, response_style)

    # A list of chat messages makes the pipeline apply the tokenizer's chat
    # template (turn markers plus generation prompt), which the
    # instruction-tuned model expects. return_full_text=False returns only
    # the newly generated text, so the prompt does not need stripping.
    outputs = generator(
        messages,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        return_full_text=False,
    )
    response = outputs[0]["generated_text"].strip()

    # Store the user's original text, without the style instruction.
    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": response})

    return "", history


# ---------------- UI ----------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet", secondary_hue="pink")) as demo:
    gr.Markdown("# 🧠 Gemma-3-270M Chat Assistant (CPU-safe)")

    # Conversation view in the "messages" format (a list of role/content
    # dicts), which is the shape chat() receives and returns.
    chatbot = gr.Chatbot(type="messages", height=500, show_copy_button=True)

    with gr.Row():
        msg = gr.Textbox(label="💬 Your Message", placeholder="Type here…", scale=6)
        send_btn = gr.Button("🚀 Send", variant="primary", scale=1)
        clear_btn = gr.Button("🧹 Clear Chat", scale=1)

    with gr.Accordion("⚙️ Advanced Settings", open=False):
        system_prompt = gr.Textbox(
            label="System Prompt",
            value=SYSTEM_PROMPT_DEFAULT,
            lines=3,
        )
        response_style = gr.Dropdown(
            ["No explanation", "Short explanation", "Detailed explanation"],
            value="Detailed explanation",
            label="Response Style",
        )
        temperature = gr.Slider(
            0.1, 1.5, value=TEMP_DEFAULT, step=0.1, label="Temperature"
        )
        top_p = gr.Slider(
            0.1, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p"
        )
        max_tokens = gr.Slider(
            32, 512, value=MAX_NEW_TOKENS_DEFAULT, step=16, label="Max new tokens"
        )

    # Positional order must match chat()'s parameters.
    chat_inputs = [msg, chatbot, system_prompt, max_tokens, temperature, top_p, response_style]
    # chat() returns "" (clears the textbox) plus the updated history.
    chat_outputs = [msg, chatbot]

    # Clicking "Send" and pressing Enter in the textbox run the same handler.
    for register_event in (send_btn.click, msg.submit):
        register_event(chat, chat_inputs, chat_outputs)

    # Resetting the history is instant, so it skips the request queue.
    clear_btn.click(lambda: [], None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()