File size: 6,190 Bytes
9d54f50
4a28811
 
 
 
 
9d54f50
bfc371b
d0b1a03
b0b5576
 
e97aac5
4a28811
5742893
4a28811
5742893
 
4a28811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr

# Bootstrap: the demo package (code, model weights and example assets) is
# pulled from a Hugging Face repository and must exist on disk before the
# `L2_InternVL` imports below can succeed.
os.system('git lfs install')
# Clone only when the checkout is missing. `git clone` refuses to write into
# an existing non-empty directory, and the exit status used to be ignored, so
# a genuine clone failure only surfaced later as an opaque import error.
if not os.path.isdir("L2_InternVL"):
    if os.system("git clone https://huggingface.co/Pluto0616/L2_InternVL") != 0:
        raise RuntimeError(
            "Failed to clone https://huggingface.co/Pluto0616/L2_InternVL; "
            "check that git is installed and the network is reachable."
        )

from L2_InternVL.utils import load_json, init_logger
from L2_InternVL.demo import ConversationalAgent, CustomTheme

# JSON file passed to load_json() in main(); indexed there by cuisine name.
FOOD_EXAMPLES = "./L2_InternVL/demo/food_for_demo.json"
# Alternative checkpoint path kept for reference:
# MODEL_PATH = "/root/share/new_models/OpenGVLab/InternVL2-2B"
MODEL_PATH = "./L2_InternVL/work_dirs/internvl_v2_internlm2_2b_lora_finetune_food/lr35_ep10"
# Directory handed to both the logger and the conversational agent.
OUTPUT_PATH = "./L2_InternVL/outputs"

def setup_seeds(seed: int = 42) -> None:
    """Seed the random number generators used by the demo.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch,
    then configures cuDNN for deterministic behaviour.

    Args:
        seed: Seed value applied to all three generators. Defaults to 42,
            the value that was previously hard-coded, so existing calls
            without arguments behave exactly as before.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Disable cuDNN autotuning and request deterministic algorithms so that
    # repeated runs with the same seed select the same kernels.
    cudnn.benchmark = False
    cudnn.deterministic = True


def main():
    """Build the InternVL2 food-chat Gradio demo and start serving it.

    Steps, in order: seed the RNGs, initialise logging under ``OUTPUT_PATH``,
    load the food example index from ``FOOD_EXAMPLES``, create the
    ``ConversationalAgent`` from ``MODEL_PATH``, lay out the Gradio UI, wire
    the agent callbacks to the widgets, enable the request queue and launch
    the server on 127.0.0.1:1096 with ``share=True``.
    """
    setup_seeds()
    # logging
    init_logger(OUTPUT_PATH)
    # food examples: indexed below by cuisine name
    food_examples = load_json(FOOD_EXAMPLES)

    agent = ConversationalAgent(model_path=MODEL_PATH,
                                outputs_dir=OUTPUT_PATH)

    theme = CustomTheme()

    # Two separate header lines. The comma between them matters: without it
    # Python concatenates the adjacent literals into a single string.
    # (Original author's font note on the first title: Kalam:wght@700)
    titles = [
        """<center><B><font face="Comic Sans MS" size=10>书生大模型实战营</font></B></center>""",
        """<center><B><font face="Courier" size=5>「进阶岛」InternVL 多模态模型部署微调实践</font></B></center>""",
    ]

    language = """Language: 中文 and English"""

    # Canned questions offered below the textbox.
    example_prompts = [
        ["图片中的食物通常属于哪个菜系?"],
        ["如果让你简单形容一下品尝图片中的食物的滋味,你会描述它"],
        ["去哪个地方游玩时应该品尝当地的特色美食图片中的食物?"],
        ["食用图片中的食物时,一般它上菜或摆盘时的特点是?"],
    ]

    # Cuisine keys of `food_examples`, grouped by the UI row they appear in.
    # Each key doubles as the label of its example gallery.
    food_shortcut_rows = (
        ("新疆菜", "川菜(四川,重庆)", "西北菜 (陕西,甘肃等地)"),
        ("黔菜 (贵州)", "苏菜(江苏)", "粤菜(广东等地)"),
        ("湘菜(湖南)", "闽菜(福建)", "浙菜(浙江)"),
        ("东北菜 (黑龙江等地)",),
    )

    # NOTE(review): avatars are resolved relative to this file, whereas the
    # other assets live under ./L2_InternVL — confirm these images exist here.
    base_dir = os.path.dirname(__file__)
    avatar_images = (
        os.path.join(base_dir, 'demo/user.png'),
        os.path.join(base_dir, "demo/bot.png"),
    )

    with gr.Blocks(theme=theme) as demo_chatbot:
        for title in titles:
            gr.Markdown(title)
        gr.Markdown(language)

        with gr.Row():
            # Left column: session controls, image slot, sampling settings.
            with gr.Column(scale=3):
                start_btn = gr.Button("Start Chat", variant="primary", interactive=True)
                clear_btn = gr.Button("Clear Context", interactive=False)
                image = gr.Image(type="pil", interactive=False)
                upload_btn = gr.Button("🖼️ Upload Image", interactive=False)

                with gr.Accordion("Generation Settings"):
                    top_p = gr.Slider(minimum=0, maximum=1, step=0.1,
                                      value=0.8,
                                      interactive=True,
                                      label='top-p value',
                                      visible=True)

                    temperature = gr.Slider(minimum=0, maximum=1.5, step=0.1,
                                            value=0.8,
                                            interactive=True,
                                            label='temperature',
                                            visible=True)

            # Right column: conversation view, input box, example prompts.
            with gr.Column(scale=7):
                chat_state = gr.State()
                chatbot = gr.Chatbot(label='InternVL2', height=800, avatar_images=avatar_images)
                text_input = gr.Textbox(label='User', placeholder="Please click the <Start Chat> button to start chat!", interactive=False)
                gr.Markdown("### 输入示例")

                def on_text_change(text):
                    # Any change to the textbox content makes it editable;
                    # presumably so that picking an example prompt unlocks it.
                    return gr.update(interactive=True)

                text_input.change(fn=on_text_change, inputs=text_input, outputs=text_input)
                gr.Examples(examples=example_prompts, inputs=[text_input])

        with gr.Row():
            gr.Markdown("### 食物快捷栏")
        # One gallery of example images per cuisine; clicking an entry fills
        # the image component.
        for cuisines in food_shortcut_rows:
            with gr.Row():
                for cuisine in cuisines:
                    gr.Examples(examples=food_examples[cuisine], inputs=image, label=cuisine)

        start_btn.click(agent.start_chat, [chat_state], [text_input, start_btn, clear_btn, image, upload_btn, chat_state])
        clear_btn.click(agent.restart_chat, [chat_state], [chatbot, text_input, start_btn, clear_btn, image, upload_btn, chat_state], queue=False)
        upload_btn.click(agent.upload_image, [image, chatbot, chat_state], [image, chatbot, chat_state])
        text_input.submit(
            agent.respond,
            inputs=[text_input, image, chatbot, top_p, temperature, chat_state],
            outputs=[text_input, image, chatbot, chat_state]
        )

    # Enable the queue BEFORE launching: launch() blocks when run as a script,
    # so a queue() call placed after it would only execute once the server
    # has already shut down.
    demo_chatbot.queue()
    demo_chatbot.launch(share=True, server_name="127.0.0.1", server_port=1096, allowed_paths=['./'])
    

# Start the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()