yuyu061611 commited on
Commit
bfc371b
·
1 Parent(s): f641629
Files changed (2) hide show
  1. app.py +123 -0
  2. requirement.txt +14 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ import torch.backends.cudnn as cudnn
6
+ import gradio as gr
7
+
8
+
9
+ os.system('git lfs install')
10
+ os.system("pip install git+https://hf-mirror.com/Pluto0616/L2_InternVL")
11
+
12
+ from utils import load_json, init_logger
13
+ from demo import ConversationalAgent, CustomTheme
14
+
15
+ FOOD_EXAMPLES = "/Pluto0616/L2_InternVL/demo/food_for_demo.json"
16
+ # MODEL_PATH = "/root/share/new_models/OpenGVLab/InternVL2-2B"
17
+ MODEL_PATH = "/Pluto0616/L2_InternVL/work_dirs/internvl_v2_internlm2_2b_lora_finetune_food/lr35_ep10"
18
+ OUTPUT_PATH = "./outputs"
19
+
20
+ def setup_seeds():
21
+ seed = 42
22
+
23
+ random.seed(seed)
24
+ np.random.seed(seed)
25
+ torch.manual_seed(seed)
26
+
27
+ cudnn.benchmark = False
28
+ cudnn.deterministic = True
29
+
30
+
31
+ def main():
32
+ setup_seeds()
33
+ # logging
34
+ init_logger(OUTPUT_PATH)
35
+ # food examples
36
+ food_examples = load_json(FOOD_EXAMPLES)
37
+
38
+ agent = ConversationalAgent(model_path=MODEL_PATH,
39
+ outputs_dir=OUTPUT_PATH)
40
+
41
+ theme = CustomTheme()
42
+
43
+ titles = [
44
+ """<center><B><font face="Comic Sans MS" size=10>书生大模型实战营</font></B></center>""" ## Kalam:wght@700
45
+ """<center><B><font face="Courier" size=5>「进阶岛」InternVL 多模态模型部署微调实践</font></B></center>"""
46
+ ]
47
+
48
+ language = """Language: 中文 and English"""
49
+ with gr.Blocks(theme) as demo_chatbot:
50
+ for title in titles:
51
+ gr.Markdown(title)
52
+ # gr.Markdown(article)
53
+ gr.Markdown(language)
54
+
55
+ with gr.Row():
56
+ with gr.Column(scale=3):
57
+ start_btn = gr.Button("Start Chat", variant="primary", interactive=True)
58
+ clear_btn = gr.Button("Clear Context", interactive=False)
59
+ image = gr.Image(type="pil", interactive=False)
60
+ upload_btn = gr.Button("🖼️ Upload Image", interactive=False)
61
+
62
+ with gr.Accordion("Generation Settings"):
63
+ top_p = gr.Slider(minimum=0, maximum=1, step=0.1,
64
+ value=0.8,
65
+ interactive=True,
66
+ label='top-p value',
67
+ visible=True)
68
+
69
+ temperature = gr.Slider(minimum=0, maximum=1.5, step=0.1,
70
+ value=0.8,
71
+ interactive=True,
72
+ label='temperature',
73
+ visible=True)
74
+
75
+ with gr.Column(scale=7):
76
+ chat_state = gr.State()
77
+ chatbot = gr.Chatbot(label='InternVL2', height=800, avatar_images=((os.path.join(os.path.dirname(__file__), 'demo/user.png')), (os.path.join(os.path.dirname(__file__), "demo/bot.png"))))
78
+ text_input = gr.Textbox(label='User', placeholder="Please click the <Start Chat> button to start chat!", interactive=False)
79
+ gr.Markdown("### 输入示例")
80
+ def on_text_change(text):
81
+ return gr.update(interactive=True)
82
+ text_input.change(fn=on_text_change, inputs=text_input, outputs=text_input)
83
+ gr.Examples(
84
+ examples=[["图片中的食物通常属于哪个菜系?"],
85
+ ["如果让你简单形容一下品尝图片中的食物的滋味,你会描述它"],
86
+ ["去哪个地方游玩时应该品尝当地的特色美食图片中的食物?"],
87
+ ["食用图片中的食物时,一般它上菜或摆盘时的特点是?"]],
88
+ inputs=[text_input]
89
+ )
90
+
91
+ with gr.Row():
92
+ gr.Markdown("### 食物快捷栏")
93
+ with gr.Row():
94
+ example_xinjiang_food = gr.Examples(examples=food_examples["新疆菜"], inputs=image, label="新疆菜")
95
+ example_sichuan_food = gr.Examples(examples=food_examples["川菜(四川,重庆)"], inputs=image, label="川菜(四川,重庆)")
96
+ example_xibei_food = gr.Examples(examples=food_examples["西北菜 (陕西,甘肃等地)"], inputs=image, label="西北菜 (陕西,甘肃等地)")
97
+ with gr.Row():
98
+ example_guizhou_food = gr.Examples(examples=food_examples["黔菜 (贵州)"], inputs=image, label="黔菜 (贵州)")
99
+ example_jiangsu_food = gr.Examples(examples=food_examples["苏菜(江苏)"], inputs=image, label="苏菜(江苏)")
100
+ example_guangdong_food = gr.Examples(examples=food_examples["粤菜(广东等地)"], inputs=image, label="粤菜(广东等地)")
101
+ with gr.Row():
102
+ example_hunan_food = gr.Examples(examples=food_examples["湘菜(湖南)"], inputs=image, label="湘菜(湖南)")
103
+ example_fujian_food = gr.Examples(examples=food_examples["闽菜(福建)"], inputs=image, label="闽菜(福建)")
104
+ example_zhejiang_food = gr.Examples(examples=food_examples["浙菜(浙江)"], inputs=image, label="浙菜(浙江)")
105
+ with gr.Row():
106
+ example_dongbei_food = gr.Examples(examples=food_examples["东北菜 (黑龙江等地)"], inputs=image, label="东北菜 (黑龙江等地)")
107
+
108
+
109
+ start_btn.click(agent.start_chat, [chat_state], [text_input, start_btn, clear_btn, image, upload_btn, chat_state])
110
+ clear_btn.click(agent.restart_chat, [chat_state], [chatbot, text_input, start_btn, clear_btn, image, upload_btn, chat_state], queue=False)
111
+ upload_btn.click(agent.upload_image, [image, chatbot, chat_state], [image, chatbot, chat_state])
112
+ text_input.submit(
113
+ agent.respond,
114
+ inputs=[text_input, image, chatbot, top_p, temperature, chat_state],
115
+ outputs=[text_input, image, chatbot, chat_state]
116
+ )
117
+
118
+ demo_chatbot.launch(share=True, server_name="127.0.0.1", server_port=1096, allowed_paths=['./'])
119
+ demo_chatbot.queue()
120
+
121
+
122
+ if __name__ == "__main__":
123
+ main()
requirement.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ lmdeploy==0.6.1
2
+ gradio==4.44.1
3
+ timm==1.0.9
4
+ xtuner==0.1.23 timm==1.0.9
5
+ 'xtuner[deepspeed]'
6
+ torch==2.4.1
7
+ torchvision==0.19.1
8
+ torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
9
+ transformers==4.39.0
10
+ tokenizers==0.15.2
11
+ peft==0.13.2
12
+ datasets==3.1.0
13
+ accelerate==1.2.0
14
+ huggingface-hub==0.26.5