mindvridge / app.py
MindVR's picture
Update app.py
5cbcb86 verified
raw
history blame
1.91 kB
import os
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
# Authenticate against the Hugging Face Hub only when a token is supplied
# through the HF_TOKEN environment variable.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

model_id = "MindVR/JohnTran_Fine-tune"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)

# device_map="auto" lets accelerate place the weights on the available
# device(s) at load time. The model is deliberately NOT moved again with
# model.to(...): that call is redundant here, and on a model that accelerate
# has split or offloaded it triggers a warning or a RuntimeError.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    low_cpu_mem_usage=True,
    token=HF_TOKEN,
)

# Device on which input tensors are placed before generation.
device = "cuda" if torch.cuda.is_available() else "cpu"
def build_prompt(history, new_message):
    """Build the text prompt sent to the model.

    Args:
        history: Previous conversation lines, oldest first. Normally a list
            of strings such as ``["User: hi", "AI: hello"]``. ``None`` or an
            empty value means "no history"; a bare string is treated as a
            single line; non-string entries are converted with ``str()``.
        new_message: The user's latest message.

    Returns:
        The history lines joined by newlines, followed by
        ``"User: <new_message>"`` and a trailing ``"AI:"`` cue for the model
        to complete.
    """
    if not history:
        lines = []
    elif isinstance(history, str):
        # Joining a bare string would interleave newlines between its
        # characters, so keep it as one history line instead.
        lines = [history]
    else:
        lines = [str(entry) for entry in history]
    lines.append(f"User: {new_message}\nAI:")
    return "\n".join(lines)
def _parse_history(history):
    """Normalize the history value received from the UI into a list of strings."""
    import ast
    import json

    if history is None:
        return []
    if isinstance(history, (list, tuple)):
        return [str(item) for item in history]
    if not isinstance(history, str):
        return [str(history)]

    text = history.strip()
    if not text:
        # A blank textbox means "no history", not one empty history line.
        return []

    # The UI asks for a JSON list, so try JSON first and fall back to a
    # Python literal (e.g. single-quoted strings) for backward compatibility.
    parsed = None
    try:
        parsed = json.loads(text)
    except (ValueError, RecursionError):
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            parsed = None

    if isinstance(parsed, (list, tuple)):
        return [str(item) for item in parsed]
    if isinstance(parsed, str):
        return [parsed]
    # Unparseable (or not a list): treat the raw text as a single history line.
    return [text]


def chat(history, new_message):
    """Generate the model's reply to ``new_message`` given the conversation so far.

    Args:
        history: Conversation lines, either as a list of strings or as the
            text typed into the UI (a JSON / Python-literal list, or free
            text that is then used as a single history line).
        new_message: The user's latest message.

    Returns:
        The generated reply as a stripped string.
    """
    # Ensure history is a list (when typed directly in the UI it arrives as a str).
    prompt = build_prompt(_parse_history(history), new_message)

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    # Pass the attention mask explicitly: pad_token_id is set to the EOS id,
    # so generate() cannot reliably infer the mask from the input ids.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=200,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Splitting the full text on the
    # last "AI:" would return a hallucinated later turn whenever the model
    # keeps writing the dialogue on its own.
    generated = output[0][input_ids.shape[-1]:]
    response = tokenizer.decode(generated, skip_special_tokens=True)
    # Keep just the first assistant turn; drop any invented "User:" follow-up.
    response = response.split("\nUser:", 1)[0]
    return response.strip()
# Input widgets: a multi-line box for the prior conversation and a single-line
# box for the new user message.
history_box = gr.Textbox(lines=8, label="History (JSON list, ví dụ: [\"User: Xin chào\"] )")
message_box = gr.Textbox(label="New message")
# Output widget showing the generated reply.
response_box = gr.Textbox(label="AI Response")

# Wire the widgets to chat() and start the web UI with flagging disabled.
iface = gr.Interface(
    fn=chat,
    inputs=[history_box, message_box],
    outputs=response_box,
    title="MindVR Therapy Chatbot",
    allow_flagging="never",
)
iface.launch()