danhtran2mind committed · Commit 25effb7 · verified · Parent: 3715ad1

Upload 12 files
README.md CHANGED
@@ -1,13 +1,13 @@
- ---
- title: Project Medical Chat App
- emoji: 💻
- colorFrom: gray
- colorTo: indigo
- sdk: gradio
- sdk_version: 5.38.2
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Gradio Chat App
+ emoji: 💻
+ colorFrom: red
+ colorTo: green
+ sdk: gradio
+ sdk_version: 5.38.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
gradio_app.py ADDED
@@ -0,0 +1,141 @@
+ import os
+ import sys
+
+ import gradio as gr
+
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'gradio_app'))
+
+ from config import logger, MODEL_IDS
+ from model_handler import ModelHandler
+ from generator import generate_response
+
+ DESCRIPTION = '''
+ <h1><span class="intro-icon">⚕️</span> Medical Chatbot with LoRA Models</h1>
+ <h2>AI-Powered Medical Insights</h2>
+ <div class="intro-highlight">
+ <strong>Explore our advanced models, fine-tuned with LoRA for medical reasoning in Vietnamese.</strong>
+ </div>
+ <div class="intro-disclaimer">
+ <strong><span class="intro-icon">ℹ️</span> Notice:</strong> For research purposes only. AI responses may have limitations due to development, datasets, and architecture. <strong>Always consult a medical professional for health advice 🩺</strong>.
+ </div>
+ '''
+
+ def user(message, history):
+     if not isinstance(history, list):
+         history = []
+     return "", history + [[message, None]]
+
+ def create_ui(model_handler):
+     with gr.Blocks(css="static/style.css", theme=gr.themes.Default()) as demo:
+         gr.Markdown(DESCRIPTION)
+         gr.HTML('<script src="static/script.js"></script>')
+         active_gen = gr.State([False])
+         model_handler_state = gr.State(model_handler)
+
+         chatbot = gr.Chatbot(
+             elem_id="chatbot",
+             height=500,
+             show_label=False,
+             render_markdown=True
+         )
+
+         with gr.Row():
+             msg = gr.Textbox(
+                 label="Message",
+                 placeholder="Type your medical query in Vietnamese...",
+                 container=False,
+                 scale=4
+             )
+             submit_btn = gr.Button("Send", variant='primary', scale=1)
+
+         with gr.Column(scale=2):
+             with gr.Row():
+                 clear_btn = gr.Button("Clear", variant='secondary')
+                 stop_btn = gr.Button("Stop", variant='stop')
+
+         with gr.Accordion("Parameters", open=False):
+             model_dropdown = gr.Dropdown(
+                 choices=MODEL_IDS,
+                 value=MODEL_IDS[0],
+                 label="Select Model",
+                 interactive=True
+             )
+             temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, label="Temperature")
+             top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p")
+             top_k = gr.Slider(minimum=1, maximum=100, value=64, step=1, label="Top-k")
+             max_tokens = gr.Slider(minimum=128, maximum=4084, value=512, step=32, label="Max Tokens")
+             seed = gr.Slider(minimum=0, maximum=2**32, value=42, step=1, label="Random Seed")
+             auto_clear = gr.Checkbox(label="Auto Clear History", value=True,
+                                      info="Clears internal conversation history after each response but keeps displayed previous messages.")
+
+         # Example queries are in Vietnamese (e.g. "What are the symptoms of peptic ulcer disease?").
+         gr.Examples(
+             examples=[
+                 ["Khi nghi ngờ bị loét dạ dày tá tràng nên đến khoa nào tại bệnh viện để thăm khám?"],
+                 ["Triệu chứng của loét dạ dày tá tràng là gì?"],
+                 ["Tôi bị mất ngủ, tôi phải làm gì?"],
+                 ["Tôi bị trĩ, tôi có nên mổ không?"]
+             ],
+             inputs=msg,
+             label="Example Medical Queries"
+         )
+
+         model_load_output = gr.Textbox(label="Model Load Status")
+         model_dropdown.change(
+             fn=model_handler.load_model,
+             inputs=[model_dropdown, chatbot],
+             outputs=[model_load_output, chatbot]
+         )
+
+         submit_event = submit_btn.click(
+             fn=user,
+             inputs=[msg, chatbot],
+             outputs=[msg, chatbot],
+             queue=False
+         ).then(
+             fn=lambda: [True],
+             outputs=active_gen
+         ).then(
+             fn=generate_response,
+             inputs=[model_handler_state, chatbot, temperature, top_p, top_k, max_tokens, seed, active_gen, model_dropdown, auto_clear],
+             outputs=chatbot
+         )
+
+         msg.submit(
+             fn=user,
+             inputs=[msg, chatbot],
+             outputs=[msg, chatbot],
+             queue=False
+         ).then(
+             fn=lambda: [True],
+             outputs=active_gen
+         ).then(
+             fn=generate_response,
+             inputs=[model_handler_state, chatbot, temperature, top_p, top_k, max_tokens, seed, active_gen, model_dropdown, auto_clear],
+             outputs=chatbot
+         )
+
+         stop_btn.click(
+             fn=lambda: [False],
+             inputs=None,
+             outputs=active_gen,
+             cancels=[submit_event]
+         )
+
+         clear_btn.click(
+             fn=lambda: None,
+             inputs=None,
+             outputs=chatbot,
+             queue=False
+         )
+
+     return demo
+
+ def main():
+     model_handler = ModelHandler()
+     model_handler.load_model(MODEL_IDS[0], [])
+     demo = create_ui(model_handler)
+     try:
+         demo.launch(server_name="0.0.0.0", server_port=7860)
+     except Exception as e:
+         logger.error(f"Failed to launch Gradio app: {str(e)}")
+         raise
+
+ if __name__ == "__main__":
+     main()
gradio_app/__init__.py ADDED
(empty file)
gradio_app/config.py ADDED
@@ -0,0 +1,44 @@
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # LoRA configurations
+ LORA_CONFIGS = {
+     "Gemma-3-1B-Instruct-Vi-Medical-LoRA": {
+         "base_model": "unsloth/gemma-3-1b-it",
+         "lora_adapter": "danhtran2mind/Gemma-3-1B-Instruct-Vi-Medical-LoRA"
+     },
+     "Gemma-3-1B-GRPO-Vi-Medical-LoRA": {
+         "base_model": "unsloth/gemma-3-1b-it",
+         "lora_adapter": "danhtran2mind/Gemma-3-1B-GRPO-Vi-Medical-LoRA"
+     },
+     "Llama-3.2-3B-Instruct-Vi-Medical-LoRA": {
+         "base_model": "unsloth/Llama-3.2-3B-Instruct",
+         "lora_adapter": "danhtran2mind/Llama-3.2-3B-Instruct-Vi-Medical-LoRA"
+     },
+     "Llama-3.2-1B-Instruct-Vi-Medical-LoRA": {
+         "base_model": "unsloth/Llama-3.2-1B-Instruct",
+         "lora_adapter": "danhtran2mind/Llama-3.2-1B-Instruct-Vi-Medical-LoRA"
+     },
+     "Llama-3.2-3B-Reasoning-Vi-Medical-LoRA": {
+         "base_model": "unsloth/Llama-3.2-3B-Instruct",
+         "lora_adapter": "danhtran2mind/Llama-3.2-3B-Reasoning-Vi-Medical-LoRA"
+     },
+     "Qwen-3-0.6B-Instruct-Vi-Medical-LoRA": {
+         "base_model": "Qwen/Qwen3-0.6B",
+         "lora_adapter": "danhtran2mind/Qwen-3-0.6B-Instruct-Vi-Medical-LoRA"
+     },
+     "Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA": {
+         "base_model": "Qwen/Qwen3-0.6B",
+         "lora_adapter": "danhtran2mind/Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA"
+     }
+ }
+
+ # Model settings
+ MAX_INPUT_TOKEN_LENGTH = 4096
+ DEFAULT_MAX_NEW_TOKENS = 512
+ MAX_MAX_NEW_TOKENS = 2048
+
+ MODEL_IDS = list(LORA_CONFIGS.keys())
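
For reference, a minimal sketch (not part of this commit) of how LORA_CONFIGS and MODEL_IDS are consumed: each UI model ID resolves to a base checkpoint plus a LoRA adapter repository. It assumes the gradio_app/ directory is on sys.path, as gradio_app.py arranges.

    # Illustrative only: resolve a UI model ID to its base model and LoRA adapter.
    from config import LORA_CONFIGS, MODEL_IDS

    model_id = MODEL_IDS[0]          # "Gemma-3-1B-Instruct-Vi-Medical-LoRA"
    cfg = LORA_CONFIGS[model_id]
    print(cfg["base_model"])         # unsloth/gemma-3-1b-it
    print(cfg["lora_adapter"])       # danhtran2mind/Gemma-3-1B-Instruct-Vi-Medical-LoRA
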
gradio_app/generator.py ADDED
@@ -0,0 +1,207 @@
+ import copy
+ import re
+ import torch
+ from threading import Thread
+ from transformers import TextIteratorStreamer
+ from config import logger, MAX_INPUT_TOKEN_LENGTH
+ from prompts import PROMPT_FUNCTIONS
+ from response_parser import ParserState, parse_response, format_response, remove_tags
+ from utils import merge_conversation
+
+ def generate_response(model_handler, history, temperature, top_p, top_k, max_tokens, seed, active_gen, model_id, auto_clear):
+     raw_history = copy.deepcopy(history)
+
+     # Clean history by removing tags from assistant responses
+     history = [[item[0], remove_tags(item[1]) if item[1] else None] for item in history]
+
+     try:
+         # Validate history
+         if not isinstance(history, list) or not history:
+             logger.error("History is empty or not a list")
+             history = [[None, "Error: Conversation history is empty or invalid"]]
+             yield history
+             return
+         # Validate last history entry
+         if not isinstance(history[-1], (list, tuple)) or len(history[-1]) < 1 or not history[-1][0]:
+             logger.error("Last history entry is invalid or missing user message")
+             history = raw_history
+             history[-1][1] = "Error: No valid user message provided"
+             yield history
+             return
+
+         # Load model if necessary
+         if model_handler.model is None or model_handler.tokenizer is None or model_id != model_handler.current_model_id:
+             status, _ = model_handler.load_model(model_id, history)
+             if "Error" in status:
+                 logger.error(status)
+                 history[-1][1] = status
+                 yield history
+                 return
+
+         torch.manual_seed(int(seed))
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed(int(seed))
+             torch.cuda.manual_seed_all(int(seed))
+
+         # Validate prompt function
+         if model_id not in PROMPT_FUNCTIONS:
+             logger.error(f"No prompt function defined for model_id: {model_id}")
+             history[-1][1] = f"Error: No prompt function defined for model {model_id}"
+             yield history
+             return
+         prompt_fn = PROMPT_FUNCTIONS[model_id]
+
+         # Handle specific model prompt formatting
+         if model_id in [
+             "Llama-3.2-3B-Reasoning-Vi-Medical-LoRA",
+             "Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA"
+         ]:
+             if auto_clear:
+                 text = prompt_fn(model_handler.tokenizer, history[-1][0])
+             else:
+                 text = prompt_fn(model_handler.tokenizer, merge_conversation(history))
+
+             inputs = model_handler.tokenizer(
+                 [text],
+                 return_tensors="pt",
+                 padding=True,
+                 truncation=True,
+                 max_length=MAX_INPUT_TOKEN_LENGTH
+             )
+         else:
+             # Build conversation for other models
+             conversation = []
+             for msg in history:
+                 if msg[0]:
+                     conversation.append({"role": "user", "content": msg[0]})
+                 if msg[1]:
+                     clean_text = ' '.join(line for line in msg[1].split('\n') if not line.startswith('✅ Thought for')).strip()
+                     conversation.append({"role": "assistant", "content": clean_text})
+                 elif msg[0] and not msg[1]:
+                     conversation.append({"role": "assistant", "content": ""})
+
+             # Ensure at least one user message
+             if not any(msg["role"] == "user" for msg in conversation):
+                 logger.error("No valid user messages in conversation history")
+                 history = raw_history
+                 history[-1][1] = "Error: No valid user messages in conversation history"
+                 yield history
+                 return
+
+             # Apply auto_clear logic
+             if auto_clear:
+                 # Keep only the last user message and add an empty assistant response
+                 user_msgs = [msg for msg in conversation if msg["role"] == "user"]
+                 if user_msgs:
+                     conversation = [{"role": "user", "content": user_msgs[-1]["content"]}, {"role": "assistant", "content": ""}]
+                 else:
+                     logger.error("No user messages found after filtering")
+                     history = raw_history
+                     history[-1][1] = "Error: No user messages found in conversation history"
+                     yield history
+                     return
+             else:
+                 # Ensure the conversation ends with an assistant placeholder if the last message is from user
+                 if conversation and conversation[-1]["role"] == "user":
+                     conversation.append({"role": "assistant", "content": ""})
+
+             text = prompt_fn(model_handler.tokenizer, conversation)
+             tokenizer_kwargs = {
+                 "return_tensors": "pt",
+                 "padding": True,
+                 "truncation": True,
+                 "max_length": MAX_INPUT_TOKEN_LENGTH
+             }
+
+             inputs = model_handler.tokenizer(text, **tokenizer_kwargs)
+
+         if inputs is None or "input_ids" not in inputs:
+             logger.error("Tokenizer returned invalid or None output")
+             history = raw_history
+             history[-1][1] = "Error: Failed to tokenize input"
+             yield history
+             return
+
+         input_ids = inputs["input_ids"].to(model_handler.model.device)
+         attention_mask = inputs.get("attention_mask").to(model_handler.model.device) if "attention_mask" in inputs else None
+
+         generate_kwargs = {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "max_new_tokens": max_tokens,
+             "do_sample": True,
+             "temperature": temperature,
+             "top_p": top_p,
+             "top_k": top_k,
+             "num_beams": 1,
+             "repetition_penalty": 1.0,
+             "pad_token_id": model_handler.tokenizer.pad_token_id,
+             "eos_token_id": model_handler.tokenizer.eos_token_id,
+             "use_cache": True,
+             "cache_implementation": "dynamic",
+         }
+
+         streamer = TextIteratorStreamer(model_handler.tokenizer, timeout=360.0, skip_prompt=True, skip_special_tokens=True)
+         generate_kwargs["streamer"] = streamer
+
+         def run_generation():
+             try:
+                 model_handler.model.generate(**generate_kwargs)
+             except Exception as e:
+                 logger.error(f"Generation failed: {str(e)}")
+                 raise
+
+         thread = Thread(target=run_generation)
+         thread.start()
+
+         state = ParserState()
+         if model_id in [
+             "Llama-3.2-3B-Reasoning-Vi-Medical-LoRA",
+             "Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA"
+         ]:
+             full_response = "<think>"
+         else:
+             full_response = ""
+
+         for text in streamer:
+             if not active_gen[0]:
+                 logger.info("Generation stopped by user")
+                 break
+
+             if text:
+                 logger.debug(f"Raw streamer output: {text}")
+                 text = re.sub(r'<\|\w+\|>', '', text)
+                 full_response += text
+                 state, elapsed = parse_response(full_response, state)
+
+                 collapsible, answer_part = format_response(state, elapsed)
+                 history = raw_history
+                 history[-1][1] = "\n\n".join(collapsible + [answer_part])
+                 yield history
+             else:
+                 logger.debug("Streamer returned empty text")
+
+         thread.join()
+         thread = None
+         state, elapsed = parse_response(full_response, state)
+         collapsible, answer_part = format_response(state, elapsed)
+         history = raw_history
+         history[-1][1] = "\n\n".join(collapsible + [answer_part])
+
+         if not full_response:
+             logger.warning("No response generated by model")
+             history[-1][1] = "No response generated. Please try again or select a different model."
+
+         yield history
+
+     except Exception as e:
+         logger.error(f"Error in generate: {str(e)}")
+         history = raw_history
+         if not history or not isinstance(history, list):
+             history = [[None, f"Error: {str(e)}. Please try again or select a different model."]]
+         else:
+             history[-1][1] = f"Error: {str(e)}. Please try again or select a different model."
+
+         yield history
+     finally:
+         active_gen[0] = False
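
generate_response is a generator: on every streamed chunk it yields the full chat history so Gradio can redraw the chatbot incrementally, and it stops early when active_gen[0] is set to False (what the Stop button does). A rough sketch of driving it outside the UI (illustrative, not part of this commit; the sampling values mirror the UI defaults):

    from config import MODEL_IDS
    from model_handler import ModelHandler
    from generator import generate_response

    handler = ModelHandler()
    handler.load_model(MODEL_IDS[0], [])

    history = [["Triệu chứng của loét dạ dày tá tràng là gì?", None]]
    active_gen = [True]  # flipping this to [False] interrupts streaming

    for partial_history in generate_response(
            handler, history,
            temperature=0.7, top_p=0.95, top_k=64, max_tokens=512,
            seed=42, active_gen=active_gen,
            model_id=MODEL_IDS[0], auto_clear=True):
        print(partial_history[-1][1])  # progressively longer assistant reply
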
gradio_app/model_handler.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+ import gc
+ from config import logger, LORA_CONFIGS
+
+ class ModelHandler:
+     def __init__(self):
+         self.model = None
+         self.tokenizer = None
+         self.current_model_id = None
+
+     def load_model(self, model_id, chatbot_state):
+         """Load the model, tokenizer, and apply LoRA adapter for the given model ID."""
+         try:
+             logger.info(f"Loading model: {model_id}")
+             print(f"Changing to model: {model_id}")
+             self.clear_model()
+
+             if model_id not in LORA_CONFIGS:
+                 raise ValueError(f"Invalid model ID: {model_id}")
+
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             base_model_name = LORA_CONFIGS[model_id]["base_model"]
+             lora_adapter_name = LORA_CONFIGS[model_id]["lora_adapter"]
+
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 base_model_name,
+                 trust_remote_code=True
+             )
+             self.tokenizer.use_default_system_prompt = False
+
+             if self.tokenizer.pad_token is None or self.tokenizer.pad_token == self.tokenizer.eos_token:
+                 self.tokenizer.pad_token = self.tokenizer.unk_token or "<pad>"
+                 logger.info(f"Set pad_token to {self.tokenizer.pad_token}")
+
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 base_model_name,
+                 torch_dtype=torch.float16,
+                 device_map=device,
+                 trust_remote_code=True
+             )
+
+             self.model = PeftModel.from_pretrained(self.model, lora_adapter_name)
+             self.model.eval()
+             self.model.config.pad_token_id = self.tokenizer.pad_token_id
+
+             self.current_model_id = model_id
+             chatbot_state = []
+             return f"Successfully loaded model: {model_id} with LoRA adapter {lora_adapter_name}", chatbot_state
+         except Exception as e:
+             logger.error(f"Failed to load model or tokenizer: {str(e)}")
+             return f"Error: Failed to load model {model_id}: {str(e)}", chatbot_state
+
+     def clear_model(self):
+         """Clear the current model and tokenizer from memory."""
+         if self.model is not None:
+             print("Clearing previous model from RAM/VRAM...")
+             del self.model
+             del self.tokenizer
+             self.model = None
+             self.tokenizer = None
+             gc.collect()
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+             print("Memory cleared successfully.")
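
A minimal usage sketch for ModelHandler (illustrative, not part of this commit; it assumes the gradio_app/ directory is on sys.path and that the base checkpoints and LoRA adapters can be downloaded from the Hub):

    from config import MODEL_IDS
    from model_handler import ModelHandler

    handler = ModelHandler()
    status, chat_state = handler.load_model(MODEL_IDS[0], [])  # base model + LoRA adapter
    print(status)  # "Successfully loaded model: ..." or "Error: ..."

    # Switching models first frees the previous one from RAM/VRAM via clear_model().
    status, chat_state = handler.load_model(MODEL_IDS[1], chat_state)
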
gradio_app/prompts.py ADDED
@@ -0,0 +1,93 @@
+ from config import logger
+
+ def gemma_3_1b_instruct_vi_medical_lora(tokenizer, messages):
+     """Prompt style for Gemma-3-1B-Instruct-Vi-Medical-LoRA: Simple user prompt with chat template"""
+     return tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=False
+     )
+
+ def gemma_3_1b_grpo_vi_medical_lora(tokenizer, messages):
+     """Prompt style for Gemma-3-1B-GRPO-Vi-Medical-LoRA: System prompt with reasoning and answer format"""
+     # Vietnamese system prompt: "Answer in the following format: <reasoning>...</reasoning> <answer>...</answer>"
+     SYSTEM_PROMPT = """
+ Trả lời theo định dạng sau đây:
+ <reasoning>
+ ...
+ </reasoning>
+ <answer>
+ ...
+ </answer>
+ """
+     if not messages or not isinstance(messages, list) or not messages[0].get("role") == "user":
+         # Fallback user message: "Please provide a question for me to answer."
+         return tokenizer.apply_chat_template(
+             [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": "Vui lòng cung cấp câu hỏi để tôi trả lời."}],
+             add_generation_prompt=True,
+             tokenize=False
+         )
+
+     conversation = [{"role": "system", "content": SYSTEM_PROMPT}]
+     for i, msg in enumerate(messages):
+         conversation.append(msg)
+         if msg["role"] == "user" and (i == len(messages) - 1 or messages[i + 1]["role"] != "assistant"):
+             conversation.append({"role": "assistant", "content": ""})
+
+     return tokenizer.apply_chat_template(
+         conversation,
+         add_generation_prompt=True,
+         tokenize=False
+     )
+
+ def llama_3_2_3b_instruct_vi_medical_lora(tokenizer, messages):
+     """Prompt style for Llama-3.2-3B-Instruct-Vi-Medical-LoRA: Extract answer from context"""
+     # Vietnamese system instruction: extract verbatim spans from the given context that answer
+     # the user's question, output the minimum needed, or state that the context has no answer.
+     instruction = '''Bạn là một trợ lý hữu ích được giao nhiệm vụ trích xuất các đoạn văn trả lời câu hỏi của người dùng từ một ngữ cảnh cho trước. Xuất ra các đoạn văn chính xác từng từ một trả lời câu hỏi của người dùng. Không xuất ra bất kỳ văn bản nào khác ngoài các đoạn văn trong ngữ cảnh. Xuất ra lượng tối thiểu để trả lời câu hỏi, ví dụ chỉ 2-3 từ từ đoạn văn. Nếu không thể tìm thấy câu trả lời trong ngữ cảnh, xuất ra 'Ngữ cảnh không cung cấp câu trả lời...' '''
+     return tokenizer.apply_chat_template(
+         [{"role": "system", "content": instruction}] + messages,
+         add_generation_prompt=True,
+         tokenize=False
+     )
+
+ def llama_3_2_1b_instruct_vi_medical_lora(tokenizer, messages):
+     """Prompt style for Llama-3.2-1B-Instruct-Vi-Medical-LoRA: Extract answer from context"""
+     return llama_3_2_3b_instruct_vi_medical_lora(tokenizer, messages)
+
+ def llama_3_2_3b_reasoning_vi_medical_lora(tokenizer, question):
+     """Prompt style for Llama-3.2-3B-Reasoning-Vi-Medical-LoRA: Reasoning prompt with think tag"""
+     # Vietnamese instruction template: act as a medical expert, think step by step,
+     # then answer the medical question inside a <think> block.
+     inference_prompt_style = """Bên dưới là một hướng dẫn mô tả một tác vụ, đi kèm với một thông tin đầu vào để cung cấp thêm ngữ cảnh.
+ Hãy viết một phản hồi để hoàn thành yêu cầu một cách phù hợp.
+ Trước khi trả lời, hãy suy nghĩ cẩn thận về câu hỏi và tạo một chuỗi suy nghĩ từng bước để đảm bảo phản hồi logic và chính xác.
+
+ ### Instruction:
+ Bạn là một chuyên gia y tế có kiến thức chuyên sâu về lập luận lâm sàng, chẩn đoán và lập kế hoạch điều trị.
+ Vui lòng trả lời câu hỏi y tế sau đây.
+
+ ### Question:
+ {}
+
+ ### Response:
+ <think>
+ """
+     return inference_prompt_style.format(question) + tokenizer.eos_token
+
+ def qwen_3_0_6b_instruct_vi_medical_lora(tokenizer, messages):
+     """Prompt style for Qwen-3-0.6B-Instruct-Vi-Medical-LoRA: Qwen-specific with enable_thinking=False"""
+     return tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=False,
+         enable_thinking=False
+     )
+
+ def qwen_3_0_6b_reasoning_vi_medical_lora(tokenizer, question):
+     """Prompt style for Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA: Same as Llama-3.2-3B-Reasoning-Vi-Medical-LoRA"""
+     return llama_3_2_3b_reasoning_vi_medical_lora(tokenizer, question)
+
+ PROMPT_FUNCTIONS = {
+     "Gemma-3-1B-Instruct-Vi-Medical-LoRA": gemma_3_1b_instruct_vi_medical_lora,
+     "Gemma-3-1B-GRPO-Vi-Medical-LoRA": gemma_3_1b_grpo_vi_medical_lora,
+     "Llama-3.2-3B-Instruct-Vi-Medical-LoRA": llama_3_2_3b_instruct_vi_medical_lora,
+     "Llama-3.2-1B-Instruct-Vi-Medical-LoRA": llama_3_2_1b_instruct_vi_medical_lora,
+     "Llama-3.2-3B-Reasoning-Vi-Medical-LoRA": llama_3_2_3b_reasoning_vi_medical_lora,
+     "Qwen-3-0.6B-Instruct-Vi-Medical-LoRA": qwen_3_0_6b_instruct_vi_medical_lora,
+     "Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA": qwen_3_0_6b_reasoning_vi_medical_lora
+ }
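
The prompt functions follow two calling conventions: the chat-template functions take a list of {"role", "content"} messages, while the reasoning functions take the raw question string and return an Alpaca-style template that ends with an opened <think> tag. A sketch (illustrative, not part of this commit; tokenizers are loaded from the base checkpoints listed in config.py):

    from transformers import AutoTokenizer
    from prompts import PROMPT_FUNCTIONS

    # Chat-template style: list of role/content messages.
    gemma_tok = AutoTokenizer.from_pretrained("unsloth/gemma-3-1b-it")
    chat_fn = PROMPT_FUNCTIONS["Gemma-3-1B-Instruct-Vi-Medical-LoRA"]
    print(chat_fn(gemma_tok, [{"role": "user", "content": "Tôi bị mất ngủ, tôi phải làm gì?"}]))

    # Reasoning style: raw question string.
    qwen_tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
    reason_fn = PROMPT_FUNCTIONS["Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA"]
    print(reason_fn(qwen_tok, "Tôi bị mất ngủ, tôi phải làm gì?"))
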
gradio_app/response_parser.py ADDED
@@ -0,0 +1,121 @@
+ import re
+ import time
+ from config import logger
+
+ class ParserState:
+     __slots__ = ['answer', 'thought', 'in_think', 'in_answer', 'start_time', 'last_pos', 'total_think_time']
+     def __init__(self):
+         self.answer = ""
+         self.thought = ""
+         self.in_think = False
+         self.in_answer = False
+         self.start_time = 0
+         self.last_pos = 0
+         self.total_think_time = 0.0
+
+ def format_time(seconds_float):
+     total_seconds = int(round(seconds_float))
+     hours = total_seconds // 3600
+     remaining_seconds = total_seconds % 3600
+     minutes = remaining_seconds // 60
+     seconds = remaining_seconds % 60
+
+     if hours > 0:
+         return f"{hours}h {minutes}m {seconds}s"
+     elif minutes > 0:
+         return f"{minutes}m {seconds}s"
+     else:
+         return f"{seconds}s"
+
+ def parse_response(text, state):
+     buffer = text[state.last_pos:]
+     state.last_pos = len(text)
+
+     while buffer:
+         if not state.in_think and not state.in_answer:
+             think_start = buffer.find('<think>')
+             reasoning_start = buffer.find('<reasoning>')
+             answer_start = buffer.find('<answer>')
+
+             starts = []
+             if think_start != -1:
+                 starts.append((think_start, '<think>', 7, 'think'))
+             if reasoning_start != -1:
+                 starts.append((reasoning_start, '<reasoning>', 11, 'think'))
+             if answer_start != -1:
+                 starts.append((answer_start, '<answer>', 8, 'answer'))
+
+             if not starts:
+                 state.answer += buffer
+                 break
+
+             start_pos, start_tag, tag_length, mode = min(starts, key=lambda x: x[0])
+
+             state.answer += buffer[:start_pos]
+             if mode == 'think':
+                 state.in_think = True
+                 state.start_time = time.perf_counter()
+             else:
+                 state.in_answer = True
+             buffer = buffer[start_pos + tag_length:]
+
+         elif state.in_think:
+             think_end = buffer.find('</think>')
+             reasoning_end = buffer.find('</reasoning>')
+
+             ends = []
+             if think_end != -1:
+                 ends.append((think_end, '</think>', 8))
+             if reasoning_end != -1:
+                 ends.append((reasoning_end, '</reasoning>', 12))
+
+             if ends:
+                 end_pos, end_tag, tag_length = min(ends, key=lambda x: x[0])
+                 state.thought += buffer[:end_pos]
+                 duration = time.perf_counter() - state.start_time
+                 state.total_think_time += duration
+                 state.in_think = False
+                 buffer = buffer[end_pos + tag_length:]
+                 if end_tag == '</reasoning>':
+                     state.answer += buffer
+                     break
+             else:
+                 state.thought += buffer
+                 break
+
+         elif state.in_answer:
+             answer_end = buffer.find('</answer>')
+             if answer_end != -1:
+                 state.answer += buffer[:answer_end]
+                 state.in_answer = False
+                 buffer = buffer[answer_end + 9:]
+             else:
+                 state.answer += buffer
+                 break
+
+     elapsed = time.perf_counter() - state.start_time if state.in_think else 0
+     return state, elapsed
+
+ def format_response(state, elapsed):
+     answer_part = state.answer
+     collapsible = []
+     collapsed = "<details open>"
+
+     if state.thought or state.in_think:
+         if state.in_think:
+             total_elapsed = state.total_think_time + elapsed
+             formatted_time = format_time(total_elapsed)
+             status = f"💭 Thinking for {formatted_time}"
+         else:
+             formatted_time = format_time(state.total_think_time)
+             status = f"✅ Thought for {formatted_time}"
+             collapsed = "<details>"
+         collapsible.append(
+             f"{collapsed}<summary>{status}</summary>\n\n<div class='thinking-container'>\n{state.thought}\n</div>\n</details>"
+         )
+     return collapsible, answer_part
+
+ def remove_tags(text):
+     if text is None:
+         return None
+     return re.sub(r'<[^>]+>', ' ', text).strip()
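
parse_response is incremental: it receives the full accumulated text on each call but only consumes the new suffix (tracked by state.last_pos), routing <think>/<reasoning> content into state.thought and everything else into state.answer. A small sketch of the streaming loop (illustrative, not part of this commit):

    from response_parser import ParserState, parse_response, format_response

    state = ParserState()
    chunks = ["<think>The patient", " likely has...</think>", "You should see a doctor."]

    full_text = ""
    for chunk in chunks:
        full_text += chunk
        state, elapsed = parse_response(full_text, state)
        collapsible, answer = format_response(state, elapsed)
        # collapsible holds a <details> block with the thinking trace;
        # answer holds the text outside the think/reasoning tags.
        print(answer)
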
gradio_app/static/script.js ADDED
@@ -0,0 +1,10 @@
+ function copyToClipboard(elementId) {
+     const element = document.getElementById(elementId);
+     let text = element.innerText.replace(/^Thinking Process:\n|^Final Answer:\n/, '');
+     text = text.replace(/mjx-[^\s]+/g, '');
+     navigator.clipboard.writeText(text).then(() => {
+         alert('Copied to clipboard!');
+     }).catch(err => {
+         console.error('Failed to copy: ', err);
+     });
+ }
gradio_app/static/styles.css ADDED
@@ -0,0 +1,148 @@
+ .intro-container {
+     max-width: 800px;
+     padding: 40px;
+     background: #ffffff;
+     border-radius: 15px;
+     box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1);
+     text-align: center;
+     animation: fadeIn 1s ease-in-out;
+ }
+ h1 {
+     font-size: 1.5em;
+     color: #007bff;
+     text-transform: uppercase;
+     letter-spacing: 1px;
+     margin-bottom: 20px;
+ }
+ h2 {
+     font-size: 1.3em;
+     color: #555555;
+     margin-bottom: 30px;
+ }
+ .intro-highlight {
+     font-size: 1.5em;
+     color: #333333;
+     margin: 20px 0;
+     padding: 20px;
+     background: #f8f9fa;
+     border-left: 5px solid #007bff;
+     border-radius: 10px;
+     transition: transform 0.3s ease;
+ }
+ .intro-highlight:hover {
+     transform: scale(1.02);
+ }
+ .intro-disclaimer {
+     font-size: 1.3em;
+     color: #333333;
+     background: #e9ecef;
+     padding: 20px;
+     border-radius: 10px;
+     border: 1px solid #007bff;
+     margin-top: 30px;
+ }
+ strong {
+     color: #007bff;
+     font-weight: bold;
+ }
+ .intro-icon {
+     font-size: 1.4em;
+     margin-right: 8px;
+ }
+ @keyframes fadeIn {
+     0% { opacity: 0; transform: translateY(-20px); }
+     100% { opacity: 1; transform: translateY(0); }
+ }
+
+ .spinner {
+     animation: spin 1s linear infinite;
+     display: inline-block;
+     margin-right: 8px;
+ }
+ @keyframes spin {
+     from { transform: rotate(0deg); }
+     to { transform: rotate(360deg); }
+ }
+ .thinking-summary {
+     cursor: pointer;
+     padding: 8px;
+     background: #f5f5f5;
+     border-radius: 4px;
+     margin: 4px 0;
+ }
+ .thought-content {
+     padding: 10px;
+     background: none;
+     border-radius: 4px;
+     margin: 5px 0;
+ }
+ .thinking-container {
+     border-left: 3px solid #facc15;
+     padding-left: 10px;
+     margin: 8px 0;
+     background: none;
+ }
+ .thinking-container:empty {
+     background: #e0e0e0;
+ }
+ details:not([open]) .thinking-container {
+     border-left-color: #290c15;
+ }
+ details {
+     border: 1px solid #e0e0e0 !important;
+     border-radius: 8px !important;
+     padding: 12px !important;
+     margin: 8px 0 !important;
+     transition: border-color 0.2s;
+ }
+ .think-section {
+     background-color: #e6f3ff;
+     border-left: 4px solid #4a90e2;
+     padding: 15px;
+     margin: 10px 0;
+     border-radius: 6px;
+     font-size: 14px;
+ }
+ .final-answer {
+     background-color: #f0f4f8;
+     border-left: 4px solid #2ecc71;
+     padding: 15px;
+     margin: 10px 0;
+     border-radius: 6px;
+     font-size: 14px;
+ }
+ #output-container {
+     position: relative;
+ }
+ .copy-button {
+     position: absolute;
+     top: 10px;
+     right: 10px;
+     padding: 5px 10px;
+     background-color: #4a90e2;
+     color: white;
+     border: none;
+     border-radius: 4px;
+     cursor: pointer;
+ }
+ .copy-button:hover {
+     background-color: #357abd;
+ }
+ .chatbot .message.assistant {
+     position: relative;
+ }
+ .chatbot .message.assistant::after {
+     content: 'Copy';
+     position: absolute;
+     top: 10px;
+     right: 10px;
+     padding: 5px 10px;
+     background-color: #4a90e2;
+     color: white;
+     border: none;
+     border-radius: 4px;
+     cursor: pointer;
+ }
+ .chatbot .message.assistant:hover::after {
+     background-color: #357abd;
+ }
gradio_app/utils.py ADDED
@@ -0,0 +1,5 @@
+ def merge_conversation(conversation):
+     valid_pairs = [(q, a) for q, a in conversation if a is not None]
+     formatted_pairs = [f"{q} {a}." for q, a in valid_pairs]
+     result = ["-".join(formatted_pairs) + "\n"]
+     return result
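
merge_conversation flattens the completed question/answer pairs (entries whose answer is not None) into a single string wrapped in a one-element list; generator.py passes this to the reasoning prompt functions when Auto Clear History is off. For example (illustrative, not part of this commit):

    from utils import merge_conversation

    history = [
        ["Triệu chứng của loét dạ dày tá tràng là gì?", "Đau thượng vị, ợ chua"],
        ["Tôi phải làm gì?", None],  # the pending turn (answer is None) is skipped
    ]
    print(merge_conversation(history))
    # ['Triệu chứng của loét dạ dày tá tràng là gì? Đau thượng vị, ợ chua.\n']
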
requirements.txt CHANGED
@@ -1,5 +1,5 @@
- gradio
- transformers
- torch
- accelerate
- peft
+ transformers==4.51.3
+ torch==2.7.0
+ accelerate==1.7.0
+ peft==0.14.0
+ gradio