trungtienluong commited on
Commit
6dce667
·
verified ·
1 Parent(s): c3b9512

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -18
app.py CHANGED
@@ -1,23 +1,98 @@
1
  import gradio as gr
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
3
 
4
- model_name = "HuggingFaceH4/zephyr-7b-beta"
5
- model = AutoModelForCausalLM.from_pretrained(model_name)
6
- tokenizer = AutoTokenizer.from_pretrained(model_name)
7
-
8
- def chatbot_response(input_text):
9
- inputs = tokenizer(input_text, return_tensors="pt")
10
- outputs = model.generate(inputs["input_ids"], max_new_tokens=150)
11
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
12
- return response
13
-
14
- interface = gr.Interface(
15
- fn=chatbot_response,
16
- inputs="text",
17
- outputs="text",
18
- title="AI Chatbot",
19
- description="Ask me anything!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  )
21
 
22
- if __name__ == "__main__":
23
- interface.launch()
 
1
  import gradio as gr
2
+ import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from peft import PeftConfig, PeftModel
5
+ import pandas as pd
6
+ from datasets import Dataset, load_dataset
7
+ from sklearn.model_selection import train_test_split
8
 
9
+ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
10
+
11
+ model = AutoModelForCausalLM.from_pretrained(
12
+ MODEL_NAME,
13
+ device_map="auto",
14
+ trust_remote_code=True
15
+ )
16
+
17
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
18
+ tokenizer.pad_token = tokenizer.eos_token
19
+ model.gradient_checkpointing_enable()
20
+
21
+ # Load the pre-trained model with PEFT
22
+ peft_config = PeftConfig.from_pretrained("trungtienluong/experiments500czephymodelngay11t6l1")
23
+ model = PeftModel.from_pretrained(model, "trungtienluong/experiments500czephymodelngay11t6l1")
24
+
25
+ # Move the model to the appropriate device
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ model.to(device)
28
+
29
+ # Load the dataset
30
+ dataset = load_dataset("trungtienluong/500cau")
31
+ data = pd.DataFrame(dataset['train'])
32
+ train_samples, temp_samples = train_test_split(data, test_size=0.2, random_state=42)
33
+ val_samples, test_samples = train_test_split(temp_samples, test_size=0.5, random_state=42)
34
+
35
+ train_dataset = Dataset.from_pandas(train_samples)
36
+ val_dataset = Dataset.from_pandas(val_samples)
37
+ test_dataset = Dataset.from_pandas(test_samples)
38
+
39
+ def create_prompt(question):
40
+ prompt_messages = [
41
+ {"role": "system", "content": "Bạn là một chuyên gia trong lĩnh vực nhi khoa. Hãy trả lời chính xác theo explanation của từng câu. Không thêm thông tin bên ngoài."},
42
+ {"role": "user", "content": "Nhiễm trùng sơ sinh là gì?"},
43
+ {"role": "assistant", "content": "Nhiễm trùng sơ sinh là tình trạng mà một em bé mới sinh bị nhiễm khuẩn hoặc vi rút. Đây là một vấn đề nghiêm trọng có thể ảnh hưởng đến sức khỏe và thậm chí là tính mạng của trẻ sơ sinh."},
44
+ {"role": "user", "content": question}
45
+ ]
46
+ prompt = tokenizer.apply_chat_template(prompt_messages, tokenize=False, add_generation_prompt=True)
47
+ return prompt
48
+
49
+ def post_process_answer(answer):
50
+ lines = answer.split('\n')
51
+ unwanted_tags = ["<system>", "<user>", "<assistant>"]
52
+ filtered_lines = [line for line in lines if not any(tag in line for tag in unwanted_tags)]
53
+ return "\n".join(filtered_lines).strip()
54
+
55
+ def generate_answer(question):
56
+ try:
57
+ prompt = create_prompt(question)
58
+ encoding = tokenizer(prompt, return_tensors="pt").to(device)
59
+ with torch.inference_mode():
60
+ outputs = model.generate(
61
+ input_ids=encoding.input_ids,
62
+ attention_mask=encoding.attention_mask,
63
+ max_new_tokens=150
64
+ )
65
+ answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
66
+ processed_answer = post_process_answer(answer)
67
+ print(f"Generated answer: {processed_answer}")
68
+ return processed_answer
69
+ except Exception as e:
70
+ print(f"Error generating answer: {e}")
71
+ return "Error"
72
+
73
+ def get_random_test_question():
74
+ random_question = test_dataset.shuffle(seed=42)['Question'][0]
75
+ return random_question
76
+
77
+ def interface_generate_answer(question, use_test_question):
78
+ if use_test_question:
79
+ question = get_random_test_question()
80
+ answer = generate_answer(question)
81
+ return question, answer
82
+
83
+ iface = gr.Interface(
84
+ fn=interface_generate_answer,
85
+ inputs=[
86
+ gr.Textbox(lines=2, placeholder="Nhập câu hỏi của bạn ở đây...", label="Câu hỏi"),
87
+ gr.Checkbox(label="Sử dụng câu hỏi từ tập kiểm tra")
88
+ ],
89
+ outputs=[
90
+ gr.Textbox(label="Câu hỏi đã nhập hoặc từ tập kiểm tra"),
91
+ gr.Textbox(label="Câu trả lời")
92
+ ],
93
+ title="Chatbot Nhi khoa",
94
+ description="Hỏi bất kỳ câu hỏi nào về nhi khoa. Bạn có thể chọn sử dụng câu hỏi từ tập kiểm tra.",
95
+ theme="default"
96
  )
97
 
98
+ iface.launch(share=True)