import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# ---------------- CONFIG ----------------
MODEL_NAME = "google/flan-t5-base"  # ✅ CPU-friendly model
SYSTEM_PROMPT_DEFAULT = (
    "You are a formal and polite AI assistant. "
    "Always respond appropriately depending on the selected explanation style."
)
MAX_NEW_TOKENS_DEFAULT = 256
TEMP_DEFAULT = 0.7
TOP_P_DEFAULT = 0.9

# ---------------- LOAD MODEL ----------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32  # ✅ safe for CPU
)

generator = pipeline(
    "text2text-generation",  # ✅ T5 models use this, not causal LM
    model=model,
    tokenizer=tokenizer,
    device=-1  # ✅ force CPU
)

# ---------------- HELPERS ----------------
def format_prompt(chat_history, user_message, system_message, response_style):
    # Start with the system message
    prompt = system_message + "\n\n"

    # Append previous conversation content only
    for turn in chat_history:
        prompt += f"{turn['content']}\n"

    # Append the new user message
    prompt += user_message

    # Optional: add a response-style instruction
    if response_style == "No explanation":
        prompt += " Provide only the answer."
    elif response_style == "Short explanation":
        prompt += " Provide a short one-sentence explanation."
    elif response_style == "Detailed explanation":
        prompt += " Provide a detailed explanation."

    return prompt

# ---------------- CHAT FUNCTION ----------------
def chat(user_message, chat_history, system_message, max_tokens, temperature, top_p, response_style):
    chat_history = chat_history or []
    prompt = format_prompt(chat_history, user_message, system_message, response_style)

    output = generator(
        prompt,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )[0]["generated_text"]

    # Sometimes Flan-T5 outputs the question in the result, so strip it
    response = output.replace(prompt, "").strip()

    # Save user and assistant content without labels
    chat_history.append({"role": "user", "content": user_message})
    chat_history.append({"role": "assistant", "content": response})

    return "", chat_history

# ---------------- UI ----------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet", secondary_hue="pink")) as demo:
    gr.Markdown("# 🧠 FLAN-T5-Base Chat Assistant (CPU-safe)")

    chatbot = gr.Chatbot(type="messages", height=500, show_copy_button=True)

    with gr.Row():
        msg = gr.Textbox(label="💬 Your Message", placeholder="Type here…", scale=6)
        send_btn = gr.Button("🚀 Send", variant="primary", scale=1)
        clear_btn = gr.Button("🧹 Clear Chat", scale=1)

    with gr.Accordion("⚙️ Advanced Settings", open=False):
        system_prompt = gr.Textbox(label="System Prompt", value=SYSTEM_PROMPT_DEFAULT, lines=3)
        response_style = gr.Dropdown(
            ["No explanation", "Short explanation", "Detailed explanation"],
            value="Detailed explanation",
            label="Response Style"
        )
        temperature = gr.Slider(0.1, 1.5, value=TEMP_DEFAULT, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
        max_tokens = gr.Slider(32, 512, value=MAX_NEW_TOKENS_DEFAULT, step=16, label="Max new tokens")

    send_btn.click(
        chat,
        [msg, chatbot, system_prompt, max_tokens, temperature, top_p, response_style],
        [msg, chatbot]
    )
    msg.submit(
        chat,
        [msg, chatbot, system_prompt, max_tokens, temperature, top_p, response_style],
        [msg, chatbot]
    )
    clear_btn.click(lambda: [], None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()
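
# ---------------- OPTIONAL: QUICK SMOKE TEST ----------------
# A minimal sketch (not part of the original app) for exercising chat() directly,
# without launching the Gradio UI. The question below is only a placeholder.
# Uncomment to run it once the model has loaded:
#
# _, history = chat(
#     "What is the capital of France?",  # hypothetical example question
#     [],                                # empty chat history
#     SYSTEM_PROMPT_DEFAULT,
#     MAX_NEW_TOKENS_DEFAULT,
#     TEMP_DEFAULT,
#     TOP_P_DEFAULT,
#     "Short explanation",
# )
# print(history[-1]["content"])          # prints the assistant's reply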