File size: 3,956 Bytes
8d81ea9
f92b0d5
cac7ff6
f92b0d5
5214a6c
cac7ff6
5f4efa7
590a39e
 
5f4efa7
f92b0d5
590a39e
6774afa
f7a5317
f92b0d5
590a39e
 
6774afa
cac7ff6
6774afa
cac7ff6
6774afa
 
590a39e
cac7ff6
590a39e
 
cac7ff6
17da298
b4f77ad
 
590a39e
8dba69a
590a39e
385b181
8dba69a
590a39e
8dba69a
385b181
8dba69a
 
590a39e
8dba69a
70e88c8
8dba69a
70e88c8
8dba69a
f7a5317
8dba69a
590a39e
 
f92b0d5
590a39e
 
 
 
 
 
 
 
 
 
 
 
 
8dba69a
 
590a39e
8dba69a
590a39e
 
 
 
5214a6c
 
 
cac7ff6
5214a6c
590a39e
f92b0d5
5214a6c
af25cff
4e13938
 
f92b0d5
4e13938
b4f77ad
0a2169b
70e88c8
f7a5317
70e88c8
0a2169b
5214a6c
 
cac7ff6
f7a5317
590a39e
 
 
 
 
 
 
 
 
 
08ea239
f92b0d5
 
d02b539
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# ---------------- CONFIG ----------------
# Hub id of the seq2seq checkpoint; picked as a CPU-friendly model.
MODEL_NAME: str = "google/flan-t5-base"

# Initial text of the "System Prompt" box in the UI.
SYSTEM_PROMPT_DEFAULT: str = (
    "You are a formal and polite AI assistant. "
    "Always respond appropriately depending on the selected explanation style."
)

# Initial slider positions: max new tokens, sampling temperature, nucleus top-p.
MAX_NEW_TOKENS_DEFAULT, TEMP_DEFAULT, TOP_P_DEFAULT = 256, 0.7, 0.9

# ---------------- LOAD MODEL ----------------
# Runs once at import time; chat() reuses the module-level `generator`.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Full-precision float32 weights are used as the CPU-safe dtype.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32)

# T5 is an encoder-decoder model, so it is served through the
# "text2text-generation" task (not a causal-LM task). device=-1 forces CPU.
generator = pipeline(
    task="text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,
)

# ---------------- HELPERS ----------------
def format_prompt(chat_history, user_message, system_message, response_style):
    """Build the single text prompt that is fed to the seq2seq model.

    Layout: the system message, a blank line, the ``content`` of every earlier
    turn on its own line (roles are intentionally omitted), then the new user
    message followed by an optional style instruction.

    Args:
        chat_history: Iterable of turn dicts; only each turn's ``"content"``
            entry is read.
        user_message: The latest user input.
        system_message: Instruction text placed at the very start.
        response_style: Explanation-style label; an unrecognised label adds
            no instruction.

    Returns:
        The assembled prompt string.
    """
    # (label, text appended after the user message); the first matching label wins.
    style_suffixes = (
        ("No explanation", " Provide only the answer."),
        ("Short explanation", " Provide a short one-sentence explanation."),
        ("Detailed explanation", " Provide a detailed explanation."),
    )

    # One line per previous turn, content only.
    history_block = "".join(f"{turn['content']}\n" for turn in chat_history)

    instruction = next(
        (suffix for label, suffix in style_suffixes if response_style == label),
        "",
    )

    return system_message + "\n\n" + history_block + user_message + instruction

# ---------------- CHAT FUNCTION ----------------
def chat(user_message, chat_history, system_message, max_tokens, temperature, top_p, response_style):
    """Generate one assistant reply and append the exchange to the history.

    Args:
        user_message: Text typed by the user.
        chat_history: Previous conversation as a list of
            ``{"role": ..., "content": ...}`` dicts, or ``None``/empty for a
            new conversation. The list passed in is not modified.
        system_message: Instruction text placed at the top of the prompt.
        max_tokens: Upper bound on newly generated tokens; coerced to ``int``
            because the value comes from a numeric UI slider.
        temperature: Sampling temperature forwarded to the pipeline.
        top_p: Nucleus-sampling threshold forwarded to the pipeline.
        response_style: Explanation-style label understood by ``format_prompt``.

    Returns:
        A ``("", history)`` tuple. The empty string clears the input textbox.
        ``history`` is a new list holding the previous turns plus the new
        user/assistant pair. For a blank or whitespace-only message, nothing
        is generated and the previous turns are returned unchanged.
    """
    # Work on a copy so the caller's list is never mutated in place.
    history = list(chat_history or [])

    # Don't spend a full generation (or record an empty user turn) on a blank submit.
    if not user_message or not user_message.strip():
        return "", history

    prompt = format_prompt(history, user_message, system_message, response_style)

    output = generator(
        prompt,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )[0]["generated_text"]

    # Defensive: drop the prompt text if the model happens to echo it back.
    response = output.replace(prompt, "").strip()

    # Stored without role labels in the text; the role lives in the dict key.
    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": response})

    return "", history

# ---------------- UI ----------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet", secondary_hue="pink")) as demo:
    gr.Markdown("# 🧠 FLAN-T5-Base Chat Assistant (CPU-safe)")

    # Conversation view. type="messages" means its value is a list of
    # {"role": ..., "content": ...} dicts, which is the shape chat() returns.
    chatbot = gr.Chatbot(type="messages", height=500, show_copy_button=True)

    # Input row: message box plus the send and clear buttons.
    with gr.Row():
        msg = gr.Textbox(label="💬 Your Message", placeholder="Type here…", scale=6)
        send_btn = gr.Button(value="🚀 Send", variant="primary", scale=1)
        clear_btn = gr.Button(value="🧹 Clear Chat", scale=1)

    # Generation controls, collapsed by default.
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        system_prompt = gr.Textbox(label="System Prompt", value=SYSTEM_PROMPT_DEFAULT, lines=3)
        response_style = gr.Dropdown(
            choices=["No explanation", "Short explanation", "Detailed explanation"],
            value="Detailed explanation",
            label="Response Style",
        )
        temperature = gr.Slider(
            minimum=0.1, maximum=1.5, value=TEMP_DEFAULT, step=0.1, label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1, maximum=1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p"
        )
        max_tokens = gr.Slider(
            minimum=32, maximum=512, value=MAX_NEW_TOKENS_DEFAULT, step=16, label="Max new tokens"
        )

    # Component order mirrors chat()'s positional parameters; its two return
    # values go back to the textbox (cleared) and the chatbot (updated history).
    chat_inputs = [msg, chatbot, system_prompt, max_tokens, temperature, top_p, response_style]
    chat_outputs = [msg, chatbot]

    # Clicking Send and pressing Enter in the textbox run the same handler.
    send_btn.click(fn=chat, inputs=chat_inputs, outputs=chat_outputs)
    msg.submit(fn=chat, inputs=chat_inputs, outputs=chat_outputs)

    # Reset the conversation view to an empty history (registered with queue=False).
    clear_btn.click(fn=lambda: [], inputs=None, outputs=chatbot, queue=False)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()