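# A small CPU-only chat demo: a Gradio Blocks UI around google/flan-t5-base,
# served through a transformers "text2text-generation" pipeline.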
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# ---------------- CONFIG ----------------
MODEL_NAME = "google/flan-t5-base" # ✅ CPU-friendly model
SYSTEM_PROMPT_DEFAULT = (
    "You are a formal and polite AI assistant. "
    "Always respond appropriately depending on the selected explanation style."
)
MAX_NEW_TOKENS_DEFAULT = 256
TEMP_DEFAULT = 0.7
TOP_P_DEFAULT = 0.9
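# Sampling notes: temperature rescales the next-token distribution (higher = more
# random), while top-p (nucleus sampling) keeps only the smallest set of tokens
# whose cumulative probability exceeds p.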
# ---------------- LOAD MODEL ----------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32  # ✅ safe for CPU
)
generator = pipeline(
    "text2text-generation",  # ✅ T5 models use this, not causal LM
    model=model,
    tokenizer=tokenizer,
    device=-1  # ✅ force CPU
)
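# Sanity check (illustrative; the exact output depends on the model and sampling):
#   generator("Answer the question: What is 2 + 2?", max_new_tokens=8)
#   -> [{'generated_text': '4'}]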
# ---------------- HELPERS ----------------
def format_prompt(chat_history, user_message, system_message, response_style):
    # Start with the system message
    prompt = system_message + "\n\n"

    # Append previous conversation content only
    for turn in chat_history:
        prompt += f"{turn['content']}\n"

    # Append the new user message
    prompt += user_message

    # Optional: add a response-style instruction
    if response_style == "No explanation":
        prompt += " Provide only the answer."
    elif response_style == "Short explanation":
        prompt += " Provide a short one-sentence explanation."
    elif response_style == "Detailed explanation":
        prompt += " Provide a detailed explanation."

    return prompt
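# Illustrative: with an empty history,
#   format_prompt([], "What is the capital of France?", SYSTEM_PROMPT_DEFAULT, "No explanation")
# produces:
#
#   You are a formal and polite AI assistant. Always respond appropriately
#   depending on the selected explanation style.
#
#   What is the capital of France? Provide only the answer.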
# ---------------- CHAT FUNCTION ----------------
def chat(user_message, chat_history, system_message, max_tokens, temperature, top_p, response_style):
    chat_history = chat_history or []
    prompt = format_prompt(chat_history, user_message, system_message, response_style)

    output = generator(
        prompt,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )[0]['generated_text']

    # Flan-T5 sometimes echoes the prompt in its output, so strip it
    response = output.replace(prompt, "").strip()

    # Save user and assistant content without labels
    chat_history.append({"role": "user", "content": user_message})
    chat_history.append({"role": "assistant", "content": response})

    return "", chat_history
# ---------------- UI ----------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet", secondary_hue="pink")) as demo:
    gr.Markdown("# 🧠 FLAN-T5-Base Chat Assistant (CPU-safe)")

    chatbot = gr.Chatbot(type="messages", height=500, show_copy_button=True)

    with gr.Row():
        msg = gr.Textbox(label="💬 Your Message", placeholder="Type here…", scale=6)
        send_btn = gr.Button("🚀 Send", variant="primary", scale=1)
        clear_btn = gr.Button("🧹 Clear Chat", scale=1)

    with gr.Accordion("⚙️ Advanced Settings", open=False):
        system_prompt = gr.Textbox(label="System Prompt", value=SYSTEM_PROMPT_DEFAULT, lines=3)
        response_style = gr.Dropdown(
            ["No explanation", "Short explanation", "Detailed explanation"],
            value="Detailed explanation",
            label="Response Style"
        )
        temperature = gr.Slider(0.1, 1.5, value=TEMP_DEFAULT, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
        max_tokens = gr.Slider(32, 512, value=MAX_NEW_TOKENS_DEFAULT, step=16, label="Max new tokens")

    send_btn.click(
        chat,
        [msg, chatbot, system_prompt, max_tokens, temperature, top_p, response_style],
        [msg, chatbot]
    )
    msg.submit(
        chat,
        [msg, chatbot, system_prompt, max_tokens, temperature, top_p, response_style],
        [msg, chatbot]
    )
    clear_btn.click(lambda: [], None, chatbot, queue=False)
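    # Note: the clear button resets the Chatbot to an empty messages list;
    # queue=False runs the reset immediately instead of going through the event queue.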
if __name__ == "__main__":
    demo.launch()