import os
from typing import List

import torch
import gradio as gr
from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Summarization model (used to compress long chat histories)
summarizer_model_id = "facebook/bart-large-cnn"
summarizer_tokenizer = AutoTokenizer.from_pretrained(summarizer_model_id)
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(
    summarizer_model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)
# Note: do not call .to(device) here -- a model loaded with device_map="auto"
# is already dispatched by accelerate, and moving it again raises an error.


def summarize_text(text: str, max_length: int = 150) -> str:
    """Summarize `text` with BART, truncating input to the model's 1024-token limit."""
    inputs = summarizer_tokenizer(
        [text], return_tensors="pt", max_length=1024, truncation=True
    ).to(summarizer_model.device)
    summary_ids = summarizer_model.generate(
        inputs["input_ids"],
        num_beams=4,
        max_length=max_length,
        early_stopping=True,
    )
    return summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)


HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Chat model
model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
    token=HF_TOKEN,
)


# --- Single build_prompt function, with summarization added ---
def build_prompt(prompt: str, histories: List[str], new_message: str) -> str:
    prompt_text = prompt.strip() + "\n" if prompt else ""
    histories_text = "\n".join(histories) if histories else ""
    # Summarize the history if it is too long (tune this threshold as needed)
    if len(histories_text) > 3000:
        histories_text = summarize_text(histories_text, max_length=180)
    if histories_text:
        prompt_text += histories_text + "\n"
    prompt_text += f"User: {new_message}\nAI:"
    return prompt_text


def chat(prompt: str, histories: List[str], new_message: str) -> str:
    prompt_text = build_prompt(prompt, histories, new_message)
    input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids.to(model.device)
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=256,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only the model's reply: the text after the final "AI:" marker,
    # cut off before any hallucinated next "User:" turn.
    if "AI:" in output_text:
        response = output_text.split("AI:")[-1].strip()
        if "User:" in response:
            response = response.split("User:")[0].strip()
    else:
        response = output_text.strip()
    return response


with gr.Blocks() as demo:
    gr.Markdown("# MindVR Therapy Chatbot\n\nTry the UI or call the API!")
    prompt_box = gr.Textbox(
        lines=2,
        label="Prompt (system prompt, context instructions for the AI; may be left empty)",
    )
    histories_box = gr.Textbox(
        lines=8, label="Histories (one message per line, e.g. User: Hello)"
    )
    new_message_box = gr.Textbox(label="New message")
    output = gr.Textbox(label="AI Response")

    def _chat_ui(prompt, histories, new_message):
        # Histories from the UI arrive as a multiline string -> convert to a list
        histories_list = [line.strip() for line in histories.split("\n") if line.strip()]
        return chat(prompt, histories_list, new_message)

    btn = gr.Button("Send")
    btn.click(_chat_ui, inputs=[prompt_box, histories_box, new_message_box], outputs=output)

    # REST-style API endpoint taking prompt, histories, new_message
    gr.api(chat, api_name="chat_ai")

demo.launch()
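
# A minimal client sketch (assumptions: the app is served locally on Gradio's
# default port 7860, and the endpoint name mirrors gr.api(..., api_name="chat_ai"),
# which gradio_client exposes with a leading slash). This is not part of the app;
# run it from a separate process once demo.launch() is up:
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   reply = client.predict(
#       "You are a calm, supportive therapist.",                # prompt
#       ["User: Hi", "AI: Hello, how are you feeling today?"],  # histories
#       "I had trouble sleeping last night.",                   # new_message
#       api_name="/chat_ai",
#   )
#   print(reply)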