Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,23 +1,98 @@
|
|
1 |
import gradio as gr
|
|
|
2 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
)
|
21 |
|
22 |
-
|
23 |
-
interface.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
+
from peft import PeftConfig, PeftModel
|
5 |
+
import pandas as pd
|
6 |
+
from datasets import Dataset, load_dataset
|
7 |
+
from sklearn.model_selection import train_test_split
|
8 |
|
9 |
+
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
10 |
+
|
11 |
+
model = AutoModelForCausalLM.from_pretrained(
|
12 |
+
MODEL_NAME,
|
13 |
+
device_map="auto",
|
14 |
+
trust_remote_code=True
|
15 |
+
)
|
16 |
+
|
17 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
18 |
+
tokenizer.pad_token = tokenizer.eos_token
|
19 |
+
model.gradient_checkpointing_enable()
|
20 |
+
|
21 |
+
# Load the pre-trained model with PEFT
|
22 |
+
peft_config = PeftConfig.from_pretrained("trungtienluong/experiments500czephymodelngay11t6l1")
|
23 |
+
model = PeftModel.from_pretrained(model, "trungtienluong/experiments500czephymodelngay11t6l1")
|
24 |
+
|
25 |
+
# Move the model to the appropriate device
|
26 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
27 |
+
model.to(device)
|
28 |
+
|
29 |
+
# Load the dataset
|
30 |
+
dataset = load_dataset("trungtienluong/500cau")
|
31 |
+
data = pd.DataFrame(dataset['train'])
|
32 |
+
train_samples, temp_samples = train_test_split(data, test_size=0.2, random_state=42)
|
33 |
+
val_samples, test_samples = train_test_split(temp_samples, test_size=0.5, random_state=42)
|
34 |
+
|
35 |
+
train_dataset = Dataset.from_pandas(train_samples)
|
36 |
+
val_dataset = Dataset.from_pandas(val_samples)
|
37 |
+
test_dataset = Dataset.from_pandas(test_samples)
|
38 |
+
|
39 |
+
def create_prompt(question):
|
40 |
+
prompt_messages = [
|
41 |
+
{"role": "system", "content": "Bạn là một chuyên gia trong lĩnh vực nhi khoa. Hãy trả lời chính xác theo explanation của từng câu. Không thêm thông tin bên ngoài."},
|
42 |
+
{"role": "user", "content": "Nhiễm trùng sơ sinh là gì?"},
|
43 |
+
{"role": "assistant", "content": "Nhiễm trùng sơ sinh là tình trạng mà một em bé mới sinh bị nhiễm khuẩn hoặc vi rút. Đây là một vấn đề nghiêm trọng có thể ảnh hưởng đến sức khỏe và thậm chí là tính mạng của trẻ sơ sinh."},
|
44 |
+
{"role": "user", "content": question}
|
45 |
+
]
|
46 |
+
prompt = tokenizer.apply_chat_template(prompt_messages, tokenize=False, add_generation_prompt=True)
|
47 |
+
return prompt
|
48 |
+
|
49 |
+
def post_process_answer(answer):
|
50 |
+
lines = answer.split('\n')
|
51 |
+
unwanted_tags = ["<system>", "<user>", "<assistant>"]
|
52 |
+
filtered_lines = [line for line in lines if not any(tag in line for tag in unwanted_tags)]
|
53 |
+
return "\n".join(filtered_lines).strip()
|
54 |
+
|
55 |
+
def generate_answer(question):
|
56 |
+
try:
|
57 |
+
prompt = create_prompt(question)
|
58 |
+
encoding = tokenizer(prompt, return_tensors="pt").to(device)
|
59 |
+
with torch.inference_mode():
|
60 |
+
outputs = model.generate(
|
61 |
+
input_ids=encoding.input_ids,
|
62 |
+
attention_mask=encoding.attention_mask,
|
63 |
+
max_new_tokens=150
|
64 |
+
)
|
65 |
+
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
66 |
+
processed_answer = post_process_answer(answer)
|
67 |
+
print(f"Generated answer: {processed_answer}")
|
68 |
+
return processed_answer
|
69 |
+
except Exception as e:
|
70 |
+
print(f"Error generating answer: {e}")
|
71 |
+
return "Error"
|
72 |
+
|
73 |
+
def get_random_test_question():
|
74 |
+
random_question = test_dataset.shuffle(seed=42)['Question'][0]
|
75 |
+
return random_question
|
76 |
+
|
77 |
+
def interface_generate_answer(question, use_test_question):
|
78 |
+
if use_test_question:
|
79 |
+
question = get_random_test_question()
|
80 |
+
answer = generate_answer(question)
|
81 |
+
return question, answer
|
82 |
+
|
83 |
+
iface = gr.Interface(
|
84 |
+
fn=interface_generate_answer,
|
85 |
+
inputs=[
|
86 |
+
gr.Textbox(lines=2, placeholder="Nhập câu hỏi của bạn ở đây...", label="Câu hỏi"),
|
87 |
+
gr.Checkbox(label="Sử dụng câu hỏi từ tập kiểm tra")
|
88 |
+
],
|
89 |
+
outputs=[
|
90 |
+
gr.Textbox(label="Câu hỏi đã nhập hoặc từ tập kiểm tra"),
|
91 |
+
gr.Textbox(label="Câu trả lời")
|
92 |
+
],
|
93 |
+
title="Chatbot Nhi khoa",
|
94 |
+
description="Hỏi bất kỳ câu hỏi nào về nhi khoa. Bạn có thể chọn sử dụng câu hỏi từ tập kiểm tra.",
|
95 |
+
theme="default"
|
96 |
)
|
97 |
|
98 |
+
iface.launch(share=True)
|
|