File size: 5,392 Bytes
e1ca7c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c97ce95
6430bee
 
c97ce95
6430bee
a1cad33
e1ca7c7
 
 
 
 
 
ec3b10f
e1ca7c7
6430bee
e1ca7c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import sys
import gradio as gr

sys.path.append(os.path.join(os.path.dirname(__file__), 'gradio_app'))

from config import logger, MODEL_IDS
from model_handler import ModelHandler
from generator import generate_response

# HTML banner rendered at the top of the Gradio page via gr.Markdown:
# title, tagline, and a research-use disclaimer. The CSS classes
# (intro-icon, intro-highlight, intro-disclaimer) are defined in
# gradio_app/static/styles.css.
DESCRIPTION = '''
<h1><span class="intro-icon">⚕️</span> Medical Chatbot with LoRA Models</h1>
    <h2>AI-Powered Medical Insights</h2>
    <div class="intro-highlight">
        <strong>Explore our advanced models, fine-tuned with LoRA for medical reasoning in Vietnamese.</strong>
    </div>
    <div class="intro-disclaimer">
        <strong><span class="intro-icon">ℹ️</span> Notice:</strong> For research purposes only. AI responses may have limitations due to development, datasets, and architecture. <strong>Always consult a medical professional for health advice 🩺</strong>.
</div>
'''

# Load the page stylesheet from the local static directory (replaces the
# previous external CSS fetch). Use a context manager so the file handle
# is closed deterministically, and pin UTF-8 so the read does not depend
# on the host's locale encoding.
with open("gradio_app/static/styles.css", encoding="utf-8") as _css_file:
    CSS = _css_file.read()

# JS_PATH = "gradio_app/static/script.js"

def user(message, history):
    """Append the user's message to the chat log and clear the input box.

    Returns a pair: an empty string (resets the textbox) and the history
    with a new ``[message, None]`` entry appended — the ``None`` slot is
    the placeholder the generator fills with the model's reply. A
    non-list ``history`` (e.g. ``None``) is treated as an empty log.
    """
    chat_log = history if isinstance(history, list) else []
    return "", [*chat_log, [message, None]]

def create_ui(model_handler):
    """Assemble the Gradio Blocks chat interface around *model_handler*.

    Builds the chatbot panel, input row, generation-parameter accordion,
    example prompts, and wires the submit/stop/clear events. Returns the
    un-launched ``gr.Blocks`` demo; the caller is responsible for
    ``demo.launch(...)``.
    """
    with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
        gr.Markdown(DESCRIPTION)
        # NOTE(review): injecting a <script> tag via gr.HTML may be
        # stripped/sandboxed by some Gradio versions — confirm it executes.
        gr.HTML('<script src="file=gradio_app/static/script.js"></script>')
        # Mutable one-element flag: [True] while a generation is running;
        # the stop button flips it to [False] so the generator can bail out.
        active_gen = gr.State([False])
        # Wrap the handler in gr.State so it can be passed into event fns.
        model_handler_state = gr.State(model_handler)
        
        chatbot = gr.Chatbot(
            elem_id="chatbot",
            height=500,
            show_label=False,
            render_markdown=True
        )

        with gr.Row():
            msg = gr.Textbox(
                label="Message",
                placeholder="Type your medical query in Vietnamese...",
                container=False,
                scale=4
            )
            submit_btn = gr.Button("Send", variant='primary', scale=1)
        
        with gr.Column(scale=2):
            with gr.Row():
                clear_btn = gr.Button("Clear", variant='secondary')
                stop_btn = gr.Button("Stop", variant='stop')
            
            # Sampling/generation knobs forwarded verbatim to generate_response.
            with gr.Accordion("Parameters", open=False):
                model_dropdown = gr.Dropdown(
                    choices=MODEL_IDS,
                    value=MODEL_IDS[0],
                    label="Select Model",
                    interactive=True
                )
                temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, label="Temperature")
                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p")
                top_k = gr.Slider(minimum=1, maximum=100, value=64, step=1, label="Top-k")
                max_tokens = gr.Slider(minimum=128, maximum=4084, value=512, step=32, label="Max Tokens")
                seed = gr.Slider(minimum=0, maximum=2**32, value=42, step=1, label="Random Seed")
                auto_clear = gr.Checkbox(label="Auto Clear History", value=True, 
                                         info="Clears internal conversation history after each response but keeps displayed previous messages.")

        gr.Examples(
            examples=[
                ["Khi nghi ngờ bị loét dạ dày tá tràng nên đến khoa nào tại bệnh viện để thăm khám?"],
                ["Triệu chứng của loét dạ dày tá tràng là gì?"],
                ["Tôi bị mất ngủ, tôi phải làm gì?"],
                ["Tôi bị trĩ, tôi có nên mổ không?"]
            ],
            inputs=msg,
            label="Example Medical Queries"
        )
        
        # Switching models reloads weights and may rewrite the chat display.
        model_load_output = gr.Textbox(label="Model Load Status")
        model_dropdown.change(
            fn=model_handler.load_model,
            inputs=[model_dropdown, chatbot],
            outputs=[model_load_output, chatbot]
        )
        
        # Send-button chain: (1) echo user msg into history and clear the
        # textbox, (2) arm the active_gen flag, (3) stream the model reply.
        submit_event = submit_btn.click(
            fn=user,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
            queue=False
        ).then(
            fn=lambda: [True],
            outputs=active_gen
        ).then(
            fn=generate_response,
            inputs=[model_handler_state, chatbot, temperature, top_p, top_k, max_tokens, seed, active_gen, model_dropdown, auto_clear],
            outputs=chatbot
        )
        
        # Same chain for pressing Enter in the textbox.
        # NOTE(review): stop_btn only cancels submit_event, not this chain —
        # Enter-triggered generations are not cancelled; verify intended.
        msg.submit(
            fn=user,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
            queue=False
        ).then(
            fn=lambda: [True],
            outputs=active_gen
        ).then(
            fn=generate_response,
            inputs=[model_handler_state, chatbot, temperature, top_p, top_k, max_tokens, seed, active_gen, model_dropdown, auto_clear],
            outputs=chatbot
        )
        
        # Stop: disarm the flag and cancel the queued send-button chain.
        stop_btn.click(
            fn=lambda: [False],
            inputs=None,
            outputs=active_gen,
            cancels=[submit_event]
        )
        
        # Clear: wipe the visible chat panel (component reset via None).
        clear_btn.click(
            fn=lambda: None,
            inputs=None,
            outputs=chatbot,
            queue=False
        )
    
    return demo

def main():
    """Entry point: warm up the default model, build the UI, and serve it.

    Pre-loads the first model in MODEL_IDS so the app is responsive on
    first request, then launches the Gradio server on all interfaces at
    port 7860. Launch failures are logged and re-raised.
    """
    handler = ModelHandler()
    handler.load_model(MODEL_IDS[0], [])
    app = create_ui(handler)
    try:
        app.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        logger.error(f"Failed to launch Gradio app: {str(e)}")
        raise

if __name__ == "__main__":
    main()