KaiChen1998 committed · Commit 0fa20f6 · 1 Parent(s): 256f531

upload emova hf demo

Files changed (4):
  1. .gitignore +3 -0
  2. app.py +544 -4
  3. conversation_public.py +506 -0
  4. requirements.txt +29 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ __pycache__/
2
+ speech/
3
+ examples/
app.py CHANGED
@@ -1,7 +1,547 @@
1
  import gradio as gr
2
- 
3
- def greet(name):
4
-     return "Hello " + name + "!!"
5
- 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+ import hashlib
7
+ import uuid
8
+
9
+ import spaces
10
  import gradio as gr
11
+ from conversation_public import default_conversation, conv_templates, SeparatorStyle
12
+
13
+ auth_token = os.environ.get("TOKEN_FROM_SECRET")
14
+
15
+ ##########################################
16
+ # Audio part
17
+ ##########################################
18
+ from huggingface_hub import snapshot_download
19
+ snapshot_download(repo_id="Emova-ollm/emova_speech_tokenizer", local_dir='./speech', token=auth_token)
20
+
21
+ from speech.speech_utils import s2u_extract_unit_demo, get_ckpt_config_path, load_model
22
+ from speech.speech_utils import load_condition_centroid, get_config_checkpoint_file, load_U2S_model, synthesis
23
+
24
+ ####################
25
+ # S2U
26
+ ####################
27
+ reduced=True
28
+ reduced_mark = 'reduced' if reduced else 'unreduced'
29
+ unit_type = '40ms_multilingual_8888'
30
+ language = 'English'
31
+ s2u_model_name = 'SPIRAL-FSQ-CTC'
32
+
33
+ ckpt_path, config_path = get_ckpt_config_path(unit_type, language)
34
+ s2u_model = load_model(ckpt_path, config_path, s2u_model_name)
35
+
36
+ ####################
37
+ # U2S
38
+ ####################
39
+ condition2style_centroid_file = "./speech/condition_style_centroid/condition2style_centroid.txt"
40
+ condition2style_centroid_file_dict, condition2style_centroid_embedding_dict = load_condition_centroid(condition2style_centroid_file)
41
+
42
+ unit_type = '40ms_multilingual_8888_xujing_cosyvoice_FT'
43
+ language = 'Chinese'
44
+ model_config_file, model_checkpoint_file = get_config_checkpoint_file(unit_type, language)
45
+ net_g, hps = load_U2S_model(model_config_file, model_checkpoint_file, unit_type)
46
+
47
+ ####################
48
+ # task format
49
+ ####################
50
+ asr_format = "Please recognize the text corresponding to the following speech.\n"
51
+ tts_format = "Please synthesize the speech corresponding to the following text.\n"
52
+ chat_format = r'Please recognize the texts, emotion and pitch from the user question speech units and provide the texts, emotion, pitch and speech units for the assistant response. \nEmotion should be chosen from ["neutral", "happy", "sad", "angry", "surprised", "disgusted", "fearful"]. \nPitch should be chosen from ["low", "normal", "high"].\nYour output should be in json format.\nAn output example is:\n{"user question text": "", "user question emotion": "", "user question pitch": "", "assistant response text": "", "assistant response emotion": "", "assistant response pitch": "","assistant response speech": ""}\n\nuser question speech:'
53
+
54
+ @spaces.GPU(duration=5)
55
+ def s2u_asr(text, audio_file):
56
+ return asr_format + s2u_extract_unit_demo(s2u_model, audio_file, model_name=s2u_model_name, reduced=reduced)
57
+
58
+ @spaces.GPU(duration=5)
59
+ def s2u_chat(text, audio_file):
60
+ return chat_format + s2u_extract_unit_demo(s2u_model, audio_file, model_name=s2u_model_name, reduced=reduced)
61
+
62
+ def u2s_tts(text, audio_file):
63
+ return tts_format + text
64
+
65
+ mode2func = dict(
66
+ asr=s2u_asr,
67
+ chat=s2u_chat,
68
+ tts=u2s_tts,
69
+ )
70
+
71
+ ##########################################
72
+ # LLM part
73
+ ##########################################
74
+ import torch
75
+ from transformers import AutoModel, AutoProcessor, TextIteratorStreamer
76
+ from threading import Thread
77
+
78
+ model_name = "Emova-ollm/emova_llama3_1-8b"
79
+ model = AutoModel.from_pretrained(
80
+ model_name,
81
+ torch_dtype=torch.bfloat16,
82
+ use_flash_attn=True,
83
+ low_cpu_mem_usage=True,
84
+ trust_remote_code=True,
85
+ token=auth_token).eval().cuda()
86
+ processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
87
+ streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
88
+
89
+ def stream_response(model, inputs, streamer, prompt, gen_kwargs):
90
+ thread = Thread(target=model.generate, kwargs=dict(
91
+ streamer=streamer,
92
+ **inputs,
93
+ **gen_kwargs
94
+ ))
95
+ thread.start()
96
+
97
+ generated_text = prompt
98
+ for new_text in streamer:
99
+ generated_text += new_text
100
+ yield generated_text
101
+
102
+ ##########################################
103
+ # Gradio part
104
+ ##########################################
105
+ no_change_btn = gr.Button()
106
+ enable_btn = gr.Button(interactive=True)
107
+ disable_btn = gr.Button(interactive=False)
108
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
109
+
110
+ def load_demo_refresh_model_list():
111
+ print(f"load_demo.")
112
+ state = default_conversation.copy()
113
+ return state
114
+
115
+ def regenerate(state, image_process_mode):
116
+ print(f"regenerate.")
117
+ state.messages[-1][-1] = None
118
+ prev_human_msg = state.messages[-2]
119
+ if type(prev_human_msg[1]) in (tuple, list):
120
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode, *prev_human_msg[1][3:])
121
+ state.skip_next = False
122
+ return (state, state.to_gradio_chatbot_public(), "", None, None) + (disable_btn,) * 2
123
+
124
+ def clear_history():
125
+ print(f"clear_history.")
126
+ state = default_conversation.copy()
127
+ return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2 + (None,)
128
+
129
+ ############
130
+ # Show prompt in the chatbot
131
+ # Input: [state, textbox, imagebox, image_process_mode, audio_input, audio_mode]
132
+ # Return: [state, chatbot, textbox, imagebox, audio_input] + btn_list
133
+ ############
134
+ def add_text(state, text, image, image_process_mode, audio_input, audio_mode):
135
+ ############
136
+ # Input legality checking
137
+ ############
138
+ print(f"add_text. len: {len(text)}")
139
+ if len(text) <= 0 and image is None and audio_input is None:
140
+ state.skip_next = True
141
+ return (state, state.to_gradio_chatbot_public(), "", None, None) + (no_change_btn,) * 2
142
+
143
+ ############
144
+ # Re-initialize if having conducted audio conversations
145
+ ############
146
+ for i, (role, msg) in enumerate(state.messages[state.offset:]):
147
+ if isinstance(msg, tuple) and msg[-1] is not None:
148
+ state = default_conversation.copy()
149
+ break
150
+
151
+ ############
152
+ # Deal with image inputs
153
+ ############
154
+ if image is not None:
155
+ if '<image>' not in text:
156
+ text = text + '\n<image>'
157
+ text = (text, image, image_process_mode, None)
158
+ state = default_conversation.copy()
159
+
160
+ ############
161
+ # Deal with audio inputs
162
+ ############
163
+ if audio_input is not None or audio_mode == 'tts':
164
+ if isinstance(text, tuple):
165
+ if audio_mode == 'chat':
166
+ prompt = mode2func[audio_mode](text[0][:-len("\n<image>")], audio_input)
167
+ text = (prompt + "\n<image>", text[1], text[2], audio_input)
168
+ elif audio_mode == 'tts':
169
+ prompt = mode2func[audio_mode](text[0][:-len("\n<image>")], audio_input)
170
+ text = (prompt, None, None, None)
171
+ else:
172
+ prompt = mode2func[audio_mode](text, audio_input)
173
+ text = (prompt, None, None, audio_input)
174
+ else:
175
+ prompt = mode2func[audio_mode](text, audio_input)
176
+ text = (prompt, None, None, audio_input)
177
+ state = default_conversation.copy()
178
+ state.append_message(state.roles[0], text)
179
+ state.append_message(state.roles[1], None)
180
+ state.skip_next = False
181
+ print(str(state.messages))
182
+ return (state, state.to_gradio_chatbot_public(), "", None, None) + (disable_btn,) * 2
183
+
184
+ ############
185
+ # Get response
186
+ # Input: [state, temperature, top_p, max_output_tokens, speaker]
187
+ # Return: [state, chatbot] + btn_list
188
+ ############
189
+ @spaces.GPU
190
+ def http_bot(state, temperature, top_p, max_new_tokens, speaker):
191
+ print(f"http_bot.")
192
+
193
+ if state.skip_next:
194
+ yield (state, state.to_gradio_chatbot_public()) + (no_change_btn,) * 2
195
+ return
196
+
197
+ if len(state.messages) == state.offset + 2:
198
+ # First round of conversation
199
+ if 'llama-2' in model_name.lower():
200
+ template_name = "llava_llama_2"
201
+ elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
202
+ if 'orca' in model_name.lower():
203
+ template_name = "mistral_orca"
204
+ elif 'hermes' in model_name.lower():
205
+ template_name = "chatml_direct"
206
+ else:
207
+ template_name = "mistral_instruct"
208
+ elif 'llava-v1.6-34b' in model_name.lower():
209
+ template_name = "chatml_direct"
210
+ elif "v1" in model_name.lower():
211
+ if 'mmtag' in model_name.lower():
212
+ template_name = "v1_mmtag"
213
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
214
+ template_name = "v1_mmtag"
215
+ else:
216
+ template_name = "llava_v1"
217
+ elif "mpt" in model_name.lower():
218
+ template_name = "mpt"
219
+ elif "llama3" in model_name.lower():
220
+ template_name = 'llama3_demo'
221
+ else:
222
+ if 'mmtag' in model_name.lower():
223
+ template_name = "v0_mmtag"
224
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
225
+ template_name = "v0_mmtag"
226
+ else:
227
+ template_name = "llava_v0"
228
+
229
+ new_state = conv_templates[template_name].copy()
230
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
231
+ new_state.append_message(new_state.roles[1], None)
232
+ state = new_state
233
+
234
+ # Construct prompt
235
+ prompt = state.get_prompt()
236
+ all_images = state.get_images(return_pil=True)
237
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
238
+
239
+ # Make requests
240
+ pload = {
241
+ "model": model_name,
242
+ "prompt": prompt,
243
+ "temperature": float(temperature),
244
+ "top_p": float(top_p),
245
+ "max_new_tokens": int(max_new_tokens),
246
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
247
+ "images": f'List of {len(state.get_images())} images: {all_image_hash}',
248
+ }
249
+ print(f"==== request ====\n{pload}")
250
+ pload['images'] = all_images
251
+
252
+ # Process inputs
253
+ inputs = processor(text=[prompt], images=all_images if len(all_images) > 0 else None, return_tensors="pt")
254
+ inputs.to(model.device)
255
+ if len(all_images) > 0:
256
+ inputs['pixel_values'] = inputs['pixel_values'].to(model.dtype)
257
+
258
+ # Process hyperparameters
259
+ temperature = float(pload.get("temperature", 1.0))
260
+ top_p = float(pload.get("top_p", 1.0))
261
+ stop_str = pload.get("stop", None)
262
+ do_sample = True if temperature > 0.001 else False
263
+ max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
264
+ max_new_tokens = int(pload.get("max_new_tokens", 256))
265
+ max_new_tokens = min(max_new_tokens, max_context_length - inputs['input_ids'].shape[1])
266
+ gen_kwargs = dict(
267
+ do_sample=do_sample,
268
+ temperature=temperature,
269
+ top_p=top_p,
270
+ max_new_tokens=max_new_tokens,
271
+ use_cache=True,
272
+ )
273
+
274
+ if max_new_tokens < 1:
275
+ state.messages[-1][-1] = "Exceeds max token length. Please start a new conversation, thanks."
276
+ yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
277
+ return
278
+
279
+ state.messages[-1][-1] = "▌"
280
+ yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
281
+
282
+ # Stream output
283
+ try:
284
+ for generated_text in stream_response(model, inputs, streamer, prompt, gen_kwargs):
285
+ output = generated_text[len(prompt):].strip()
286
+ if tts_format not in prompt and chat_format not in prompt:
287
+ state.messages[-1][-1] = output + "▌"
288
+ else:
289
+ state.messages[-1][-1] = "▌"
290
+ # state.messages[-1][-1] = "[😁 GENERATING AUDIO {}%...]".format(round(output.count("<|speech_") / max_new_tokens * 100, 1)) + "\n" + output + "▌"
291
+ yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
292
+ except Exception as e:
293
+ state.messages[-1][-1] = server_error_msg
294
+ yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
295
+ return
296
+
297
+ ################
298
+ # decode output to audio
299
+ ################
300
+ temp_file = None
301
+ if tts_format in prompt or chat_format in prompt:
302
+ try:
303
+ try:
304
+ if output.startswith("{"):
305
+ if output.endswith("|>"):
306
+ output += "\"}"
307
+ elif output.endswith("\""):
308
+ output += "}"
309
+ info_dict = json.loads(output)
310
+ content_unit = info_dict['assistant response speech'].replace('<|speech_', '').replace('|>', ' ').strip()
311
+ emotion = info_dict['assistant response emotion'] if 'assistant response emotion' in info_dict else "neutral"
312
+ speed = info_dict['assistant response speed'] if 'assistant response speed' in info_dict else "normal"
313
+ pitch = info_dict['assistant response pitch'] if 'assistant response pitch' in info_dict else "normal"
314
+ gender = speaker.lower() if speaker else 'female'
315
+ except:
316
+ content_unit = output.replace('<|speech_', '').replace('|>', ' ').strip()
317
+ emotion = 'neutral'
318
+ speed = "normal"
319
+ pitch = "normal"
320
+ gender = speaker.lower() if speaker else 'female'
321
+
322
+ condition = f'gender-{gender}_emotion-{emotion}_speed-{speed}_pitch-{pitch}'
323
+ style_centroid_file = condition2style_centroid_file_dict[condition]
324
+ style_centroid_embedding = condition2style_centroid_embedding_dict[condition]
325
+ print(condition)
326
+
327
+ id = str(uuid.uuid4())
328
+ os.makedirs("./demo_audio", exist_ok=True)
329
+ synthesis(content_unit, style_centroid_embedding, hps, net_g, f"./demo_audio/{id}_temp_audio.wav")
330
+ temp_file = f"./demo_audio/{id}_temp_audio.wav"
331
+ except Exception as e:
332
+ print(e)
333
+
334
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
335
+ if tts_format in prompt or chat_format in prompt:
336
+ if temp_file is not None:
337
+ state.messages[-1][-1] = (output, temp_file)
338
+ yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
339
+ else:
340
+ state.messages[-1][-1] = server_error_msg
341
+ yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
342
+ else:
343
+ yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
344
+
345
+ if temp_file is not None:
346
+ os.remove(temp_file)
347
+
348
+ print(f"{output}")
349
+
350
+ ############
351
+ # Layout Markdown
352
+ ############
353
+ title_markdown = ("""
354
+ <div style="display: flex; align-items: center; padding: 20px; border-radius: 10px; background-color: #f0f0f0;">
355
+ <div style="margin-right: 20px;">
356
+ <img src="https://emova-ollm.github.io/static/images/icons/emova.png" alt="Icon" style="width: 100px; height: 100px; border-radius: 10px;">
357
+ </div>
358
+ <div>
359
+ <h1 style="margin: 0;">EMOVA: Empowering Language Models to See, Hear and Speak with Vivid Emotion</h1>
360
+ <p style="margin: 10px 0;">
361
+ 1. Note that to use the webcam and microphone, open <a href="chrome://flags/#unsafely-treat-insecure-origin-as-secure">chrome://flags/#unsafely-treat-insecure-origin-as-secure</a> and add this demo page's URL to the box there.<br/>
362
+ 2. To chat with EMOVA, upload an image, enter text, or record audio, and then do not forget to <mark>click the 💬 Chat button</mark> ^v^!<br/>
363
+ 3. Increase the <code>Max output tokens</code> if necessary to talk longer with EMOVA.
364
+ </p>
365
+ </div>
366
+ </div>
367
+ """)
368
+
369
+ tos_markdown = ("""
370
+ ### Terms of use
371
+ By using this service, users are required to agree to the following terms:
372
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
373
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
374
+ """)
375
+
376
+ learn_more_markdown = ("""
377
+ ### License
378
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
379
+
380
+ ### Acknowledgement
381
+ The service is built upon [LLaVA](https://github.com/haotian-liu/LLaVA/). We thank the authors for open-sourcing their wonderful code.
382
+ """)
383
+
384
+ block_css = """
385
+ #buttons button {
386
+ min-width: min(120px,100%);
387
+ }
388
+
389
+ .message-row img {
390
+ margin: 0px !important;
391
+ }
392
+
393
+ .avatar-container img {
394
+ padding: 0px !important;
395
+ }
396
+ """
397
+
398
+ ############
399
+ # Layout Demo
400
+ ############
401
+ def build_demo(embed_mode, cur_dir=None):
402
+ textbox = gr.Textbox(label="Text", show_label=False, placeholder="Enter text or record audio in the right and then click 💬 Chat to talk with me ^v^", container=False, scale=6)
403
+ audio_input = gr.Audio(label="Audio", sources=["microphone", "upload"], type="filepath", max_length=10, show_download_button=True, waveform_options=dict(sample_rate=16000), scale=2)
404
+ with gr.Blocks(title="EMOVA", theme=gr.themes.Default(), css=block_css) as demo:
405
+ state = gr.State()
406
+ if not embed_mode:
407
+ gr.Markdown(title_markdown)
408
+
409
+ ##############
410
+ # Chatbot
411
+ ##############
412
+ with gr.Row(equal_height=True):
413
+ with gr.Column(scale=1):
414
+ imagebox = gr.Image(type="pil", label="Image")
415
+ image_process_mode = gr.Radio(
416
+ ["Crop", "Resize", "Pad", "Default"],
417
+ value="Default",
418
+ label="Preprocess for non-square image", visible=False)
419
+
420
+ ##############
421
+ # Parameters
422
+ ##############
423
+ with gr.Accordion("Parameters", open=True) as parameter_row:
424
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature")
425
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P")
426
+ max_output_tokens = gr.Slider(minimum=0, maximum=4096, value=2048, step=32, interactive=True, label="Max output tokens")
427
+ speaker = gr.Radio(["Female", "Male"], label="Speaker")
428
+
429
+ with gr.Column(scale=8):
430
+ chatbot = gr.Chatbot(
431
+ elem_id="chatbot",
432
+ label="EMOVA Chatbot",
433
+ layout="bubble",
434
+ avatar_images=["examples/user_avator.png", "examples/icon_256.png"]
435
+ )
436
+ with gr.Row(equal_height=True):
437
+ textbox.render()
438
+ audio_input.render()
439
+ with gr.Row(elem_id="buttons") as button_row:
440
+ submit_btn = gr.Button(value="💬 Chat", variant="primary")
441
+ #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
442
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
443
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
444
+
445
+ ##############
446
+ # Examples
447
+ ##############
448
+ if cur_dir is None:
449
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
450
+
451
+ with gr.Row():
452
+ with gr.Column(scale=9):
453
+ gr.Examples(examples=[
454
+ [f"{cur_dir}/examples/emo-speech/what_is_your_name.wav"],
455
+ [f"{cur_dir}/examples/emo-speech/parent.wav"],
456
+ [f"{cur_dir}/examples/emo-speech/I_am_so_sad.wav"],
457
+ [f"{cur_dir}/examples/emo-speech/wedding(CH).wav"],
458
+ ], inputs=[audio_input], label='Audio Examples')
459
+
460
+ with gr.Row(equal_height=True):
461
+ gr.Examples(examples=[
462
+ [f"{cur_dir}/examples/image-text/example_1.png", "Why is this image funny?"],
463
+ [f"{cur_dir}/examples/image-text/example_2.png", "First please perform reasoning, and think step by step to provide best answer to the following question:\n\nWhat is the original price for pork belly before discount?"],
464
+ [f"{cur_dir}/examples/image-text/example_3.png", "Convert this table to markdown format."],
465
+ ], inputs=[imagebox, textbox], label='Image Examples')
466
+ gr.Examples(examples=[
467
+ [f"{cur_dir}/examples/emo-speech/write_a_poem.jfif", f"{cur_dir}/examples/emo-speech/write_a_poem.wav"],
468
+ [f"{cur_dir}/examples/emo-speech/I_am_happy_get_my_offer.webp", f"{cur_dir}/examples/emo-speech/I_am_happy_get_my_offer.wav"],
469
+ [f"{cur_dir}/examples/structure-speech/names_of_main_actors.jpg", f"{cur_dir}/examples/structure-speech/names_of_main_actors.wav"],
470
+ ], inputs=[imagebox, audio_input], label='Omni Examples 1')
471
+ gr.Examples(examples=[
472
+ [f"{cur_dir}/examples/structure-speech/how_to_save_water.png", f"{cur_dir}/examples/structure-speech/how_to_save_water.wav"],
473
+ [f"{cur_dir}/examples/structure-speech/internet_coverage.png", f"{cur_dir}/examples/structure-speech/internet_coverage.wav"],
474
+ [f"{cur_dir}/examples/structure-speech/how_to_use_website.PNG", f"{cur_dir}/examples/structure-speech/how_to_use_website.wav"],
475
+ ], inputs=[imagebox, audio_input], label='Omni Examples 2')
476
+
477
+ if not embed_mode:
478
+ gr.Markdown(tos_markdown)
479
+ gr.Markdown(learn_more_markdown)
480
+
481
+ # Register listeners
482
+ btn_list = [regenerate_btn, clear_btn]
483
+ regenerate_btn.click(
484
+ regenerate,
485
+ [state, image_process_mode],
486
+ [state, chatbot, textbox, imagebox, audio_input] + btn_list
487
+ ).then(
488
+ http_bot,
489
+ [state, temperature, top_p, max_output_tokens, speaker],
490
+ [state, chatbot] + btn_list,
491
+ )
492
+
493
+ clear_btn.click(
494
+ clear_history,
495
+ None,
496
+ [state, chatbot, textbox, imagebox] + btn_list + [audio_input],
497
+ queue=False
498
+ )
499
+
500
+ # Triggered when the user presses Enter in the textbox
501
+ textbox.submit(
502
+ add_text,
503
+ [state, textbox, imagebox, image_process_mode, audio_input, gr.Number(value='chat', visible=False)],
504
+ [state, chatbot, textbox, imagebox, audio_input] + btn_list,
505
+ queue=False
506
+ ).then(
507
+ http_bot,
508
+ [state, temperature, top_p, max_output_tokens, speaker],
509
+ [state, chatbot] + btn_list,
510
+ )
511
+
512
+ submit_btn.click(
513
+ add_text,
514
+ [state, textbox, imagebox, image_process_mode, audio_input, gr.Number(value='chat', visible=False)],
515
+ [state, chatbot, textbox, imagebox, audio_input] + btn_list
516
+ ).then(
517
+ http_bot,
518
+ [state, temperature, top_p, max_output_tokens, speaker],
519
+ [state, chatbot] + btn_list,
520
+ )
521
+
522
+ ##############
523
+ # Demo loading
524
+ ##############
525
+ demo.load(
526
+ load_demo_refresh_model_list,
527
+ None,
528
+ [state],
529
+ queue=False
530
+ )
531
+ return demo
532
+
533
 
534
+ if __name__ == "__main__":
535
+ parser = argparse.ArgumentParser()
536
+ parser.add_argument("--share", action="store_true")
537
+ parser.add_argument("--embed", action="store_true")
538
+ args = parser.parse_args()
539
 
540
+ demo = build_demo(args.embed)
541
+ demo.queue(
542
+ api_open=False
543
+ ).launch(
544
+ favicon_path="./examples/icon_256.png",
545
+ allowed_paths=["/"],
546
+ share=args.share
547
+ )
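Note: the core inference path that app.py wires into Gradio can be exercised on its own. Below is a minimal sketch, assuming access to the gated Emova-ollm/emova_llama3_1-8b checkpoint via an HF_TOKEN environment variable, a CUDA device, and the same remote-code AutoModel/AutoProcessor interface used above; it runs a single non-streaming, text-only generation instead of the TextIteratorStreamer loop, and the question text is made up for illustration.

    # Minimal sketch of the app.py inference path (assumptions noted above).
    import os
    import torch
    from transformers import AutoModel, AutoProcessor
    from conversation_public import conv_templates

    model_name = "Emova-ollm/emova_llama3_1-8b"
    token = os.environ.get("HF_TOKEN")  # assumed to hold a token with access to the gated repo

    model = AutoModel.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True,
        trust_remote_code=True, token=token).eval().cuda()
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    # Build the prompt with the llama3_demo template that app.py selects for this model.
    conv = conv_templates["llama3_demo"].copy()
    conv.append_message(conv.roles[0], "Introduce yourself in one sentence.")
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    inputs = processor(text=[prompt], return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=128, use_cache=True)
    # Decode only the newly generated tokens (prompt tokens come first for this decoder-only model).
    response = processor.tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    print(response)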
conversation_public.py ADDED
@@ -0,0 +1,506 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+ import base64
5
+ from io import BytesIO
6
+ from PIL import Image
7
+
8
+ import base64
9
+ tts_format = "Please synthesize the speech corresponding to the following text.\n"
10
+
11
+ class SeparatorStyle(Enum):
12
+ """Different separator style."""
13
+ SINGLE = auto()
14
+ TWO = auto()
15
+ MPT = auto()
16
+ PLAIN = auto()
17
+ LLAMA_2 = auto()
18
+ GLM4 = auto()
19
+
20
+
21
+ @dataclasses.dataclass
22
+ class Conversation:
23
+ """A class that keeps all conversation history."""
24
+ system: str
25
+ roles: List[str]
26
+ messages: List[List[str]]
27
+ offset: int
28
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
29
+ sep: str = "###"
30
+ sep2: str = None
31
+ version: str = "Unknown"
32
+
33
+ skip_next: bool = False
34
+
35
+ def get_prompt(self):
36
+ messages = self.messages
37
+ if len(messages) > 0 and type(messages[0][1]) is tuple and messages[0][1][1] is not None:
38
+ messages = self.messages.copy()
39
+ init_role, init_msg = messages[0].copy()
40
+ init_msg = init_msg[0].replace("<image>", "").strip()
41
+ if 'mmtag' in self.version:
42
+ messages[0] = (init_role, init_msg)
43
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
44
+ messages.insert(1, (self.roles[1], "Received."))
45
+ else:
46
+ messages[0] = (init_role, "<image>\n" + init_msg)
47
+
48
+ if self.sep_style == SeparatorStyle.SINGLE:
49
+ ret = self.system + self.sep
50
+ for role, message in messages:
51
+ if message:
52
+ if type(message) is tuple:
53
+ message, _, _ = message[:3]
54
+ ret += role + ": " + message + self.sep
55
+ else:
56
+ ret += role + ":"
57
+ elif self.sep_style == SeparatorStyle.TWO:
58
+ seps = [self.sep, self.sep2]
59
+ ret = self.system + seps[0]
60
+ for i, (role, message) in enumerate(messages):
61
+ if message:
62
+ if type(message) is tuple:
63
+ message, _, _ = message[:3]
64
+ ret += role + ": " + message + seps[i % 2]
65
+ else:
66
+ ret += role + ":"
67
+ elif self.sep_style == SeparatorStyle.MPT:
68
+ ret = self.system + self.sep
69
+ for role, message in messages:
70
+ if message:
71
+ if type(message) is tuple:
72
+ message, _, _ = message[:3]
73
+ ret += role + message + self.sep
74
+ else:
75
+ ret += role
76
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
77
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
78
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
79
+ ret = ""
80
+
81
+ for i, (role, message) in enumerate(messages):
82
+ if i == 0:
83
+ assert message, "first message should not be none"
84
+ assert role == self.roles[0], "first message should come from user"
85
+ if message:
86
+ if type(message) is tuple:
87
+ message, _, _ = message[:3]
88
+ if i == 0: message = wrap_sys(self.system) + message
89
+ if i % 2 == 0:
90
+ message = wrap_inst(message)
91
+ ret += self.sep + message
92
+ else:
93
+ ret += " " + message + " " + self.sep2
94
+ else:
95
+ ret += ""
96
+ ret = ret.lstrip(self.sep)
97
+ elif self.sep_style == SeparatorStyle.PLAIN:
98
+ seps = [self.sep, self.sep2]
99
+ ret = self.system
100
+ for i, (role, message) in enumerate(messages):
101
+ if message:
102
+ if type(message) is tuple:
103
+ message, _, _ = message[:3]
104
+ ret += message + seps[i % 2]
105
+ else:
106
+ ret += ""
107
+ elif self.sep_style == SeparatorStyle.GLM4:
108
+ chat_roles = ("<|user|>", "<|assistant|>")  # renamed so the loop variable below does not shadow it
109
+ ret = self.system + chat_roles[0]
110
+ for i, (role, message) in enumerate(messages):
111
+ if message:
112
+ if type(message) is tuple:
113
+ message, _, _ = message[:3]
114
+ ret += self.sep + message + chat_roles[(i+1) % 2]
115
+ else:
116
+ ret += ""
117
+ else:
118
+ raise ValueError(f"Invalid style: {self.sep_style}")
119
+
120
+ return ret
121
+
122
+ def append_message(self, role, message):
123
+ if isinstance(self.messages, tuple):
124
+ self.messages += ([role, message],)
125
+ else:
126
+ self.messages.append([role, message])
127
+
128
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
129
+ if image_process_mode == "Pad":
130
+ def expand2square(pil_img, background_color=(122, 116, 104)):
131
+ width, height = pil_img.size
132
+ if width == height:
133
+ return pil_img
134
+ elif width > height:
135
+ result = Image.new(pil_img.mode, (width, width), background_color)
136
+ result.paste(pil_img, (0, (width - height) // 2))
137
+ return result
138
+ else:
139
+ result = Image.new(pil_img.mode, (height, height), background_color)
140
+ result.paste(pil_img, ((height - width) // 2, 0))
141
+ return result
142
+ image = expand2square(image)
143
+ elif image_process_mode in ["Default", "Crop"]:
144
+ pass
145
+ elif image_process_mode == "Resize":
146
+ image = image.resize((336, 336))
147
+ else:
148
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
149
+ if max(image.size) > max_len:
150
+ max_hw, min_hw = max(image.size), min(image.size)
151
+ aspect_ratio = max_hw / min_hw
152
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
153
+ longest_edge = int(shortest_edge * aspect_ratio)
154
+ W, H = image.size
155
+ if H > W:
156
+ H, W = longest_edge, shortest_edge
157
+ else:
158
+ H, W = shortest_edge, longest_edge
159
+ image = image.resize((W, H))
160
+ if return_pil:
161
+ return image
162
+ else:
163
+ buffered = BytesIO()
164
+ image.save(buffered, format=image_format)
165
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
166
+ return img_b64_str
167
+
168
+ def get_images(self, return_pil=False):
169
+ images = []
170
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
171
+ if i % 2 == 0:
172
+ if type(msg) is tuple and msg[1] is not None:
173
+ msg, image, image_process_mode = msg[:3]
174
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
175
+ images.append(image)
176
+ return images
177
+
178
+ def to_gradio_chatbot(self):
179
+ ret = []
180
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
181
+ if i % 2 == 0:
182
+ if type(msg) is tuple:
183
+ msg, image, image_process_mode = msg
184
+ img_b64_str = self.process_image(
185
+ image, "Default", return_pil=False,
186
+ image_format='JPEG')
187
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
188
+ msg = img_str + msg.replace('<image>', '').strip()
189
+ ret.append([msg, None])
190
+ else:
191
+ ret.append([msg, None])
192
+ else:
193
+ ret[-1][-1] = msg
194
+ return ret
195
+
196
+ def to_gradio_chatbot_public(self):
197
+ ret = []
198
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
199
+ if i % 2 == 0:
200
+ if type(msg) is tuple:
201
+ msg, image, image_process_mode, audio_input = msg
202
+ ret_msg = ""
203
+ if image is not None:
204
+ img_b64_str = self.process_image(
205
+ image, "Default", return_pil=False,
206
+ image_format='JPEG')
207
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
208
+ ret_msg += img_str
209
+ if audio_input is not None:
210
+ audio_b64_str = base64.b64encode(open(audio_input, "rb").read()).decode("utf-8")
211
+ audio_str = f'<audio src="data:audio/wav;base64,{audio_b64_str}" controls ></audio>'
212
+ ret_msg += audio_str
213
+ else:
214
+ ret_msg += msg.replace('<image>', '').replace(tts_format, '').strip()
215
+ ret.append([ret_msg, None])
216
+ else:
217
+ ret.append([msg, None])
218
+ else:
219
+ if type(msg) is tuple:
220
+ audio_b64_str = base64.b64encode(open(msg[1], "rb").read()).decode("utf-8")
221
+ msg = f'<audio src="data:audio/wav;base64,{audio_b64_str}" controls autoplay></audio>'
222
+ ret[-1][-1] = msg
223
+ return ret
224
+
225
+ def copy(self):
226
+ return Conversation(
227
+ system=self.system,
228
+ roles=self.roles,
229
+ messages=[[x, y] for x, y in self.messages],
230
+ offset=self.offset,
231
+ sep_style=self.sep_style,
232
+ sep=self.sep,
233
+ sep2=self.sep2,
234
+ version=self.version)
235
+
236
+ def dict(self):
237
+ if len(self.get_images()) > 0:
238
+ return {
239
+ "system": self.system,
240
+ "roles": self.roles,
241
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
242
+ "offset": self.offset,
243
+ "sep": self.sep,
244
+ "sep2": self.sep2,
245
+ }
246
+ return {
247
+ "system": self.system,
248
+ "roles": self.roles,
249
+ "messages": self.messages,
250
+ "offset": self.offset,
251
+ "sep": self.sep,
252
+ "sep2": self.sep2,
253
+ }
254
+
255
+
256
+ conv_vicuna_v0 = Conversation(
257
+ system="A chat between a curious human and an artificial intelligence assistant. "
258
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
259
+ roles=("Human", "Assistant"),
260
+ messages=(
261
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
262
+ ("Assistant",
263
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
264
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
265
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
266
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
267
+ "renewable and non-renewable energy sources:\n"
268
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
269
+ "energy sources are finite and will eventually run out.\n"
270
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
271
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
272
+ "and other negative effects.\n"
273
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
274
+ "have lower operational costs than non-renewable sources.\n"
275
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
276
+ "locations than non-renewable sources.\n"
277
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
278
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
279
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
280
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
281
+ ),
282
+ offset=2,
283
+ sep_style=SeparatorStyle.SINGLE,
284
+ sep="###",
285
+ )
286
+
287
+ conv_vicuna_v1 = Conversation(
288
+ system="A chat between a curious user and an artificial intelligence assistant. "
289
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
290
+ roles=("USER", "ASSISTANT"),
291
+ version="v1",
292
+ messages=(),
293
+ offset=0,
294
+ sep_style=SeparatorStyle.TWO,
295
+ sep=" ",
296
+ sep2="</s>",
297
+ )
298
+
299
+ conv_llama_2 = Conversation(
300
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
301
+
302
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
303
+ roles=("USER", "ASSISTANT"),
304
+ version="llama_v2",
305
+ messages=(),
306
+ offset=0,
307
+ sep_style=SeparatorStyle.LLAMA_2,
308
+ sep="<s>",
309
+ sep2="</s>",
310
+ )
311
+
312
+ conv_llava_llama_2 = Conversation(
313
+ system="You are a helpful language and vision assistant. "
314
+ "You are able to understand the visual content that the user provides, "
315
+ "and assist the user with a variety of tasks using natural language.",
316
+ roles=("USER", "ASSISTANT"),
317
+ version="llama_v2",
318
+ messages=(),
319
+ offset=0,
320
+ sep_style=SeparatorStyle.LLAMA_2,
321
+ sep="<s>",
322
+ sep2="</s>",
323
+ )
324
+
325
+ conv_mpt = Conversation(
326
+ system="""<|im_start|>system
327
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
328
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
329
+ version="mpt",
330
+ messages=(),
331
+ offset=0,
332
+ sep_style=SeparatorStyle.MPT,
333
+ sep="<|im_end|>",
334
+ )
335
+
336
+ conv_llava_plain = Conversation(
337
+ system="",
338
+ roles=("", ""),
339
+ messages=(),
340
+ offset=0,
341
+ sep_style=SeparatorStyle.PLAIN,
342
+ sep="\n",
343
+ )
344
+
345
+ conv_llava_v0 = Conversation(
346
+ system="A chat between a curious human and an artificial intelligence assistant. "
347
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
348
+ roles=("Human", "Assistant"),
349
+ messages=(),
350
+ offset=0,
351
+ sep_style=SeparatorStyle.SINGLE,
352
+ sep="###",
353
+ )
354
+
355
+ conv_llava_v0_mmtag = Conversation(
356
+ system="A chat between a curious user and an artificial intelligence assistant. "
357
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
358
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
359
+ roles=("Human", "Assistant"),
360
+ messages=(
361
+ ),
362
+ offset=0,
363
+ sep_style=SeparatorStyle.SINGLE,
364
+ sep="###",
365
+ version="v0_mmtag",
366
+ )
367
+
368
+ conv_llava_v1 = Conversation(
369
+ system="A chat between a curious human and an artificial intelligence assistant. "
370
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
371
+ roles=("USER", "ASSISTANT"),
372
+ version="v1",
373
+ messages=(),
374
+ offset=0,
375
+ sep_style=SeparatorStyle.TWO,
376
+ sep=" ",
377
+ sep2="</s>",
378
+ )
379
+
380
+ conv_llava_v1_mmtag = Conversation(
381
+ system="A chat between a curious user and an artificial intelligence assistant. "
382
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
383
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
384
+ roles=("USER", "ASSISTANT"),
385
+ messages=(),
386
+ offset=0,
387
+ sep_style=SeparatorStyle.TWO,
388
+ sep=" ",
389
+ sep2="</s>",
390
+ version="v1_mmtag",
391
+ )
392
+
393
+ conv_mistral_instruct = Conversation(
394
+ system="",
395
+ roles=("USER", "ASSISTANT"),
396
+ version="llama_v2",
397
+ messages=(),
398
+ offset=0,
399
+ sep_style=SeparatorStyle.LLAMA_2,
400
+ sep="",
401
+ sep2="</s>",
402
+ )
403
+
404
+ conv_chatml_direct = Conversation(
405
+ system="""<|im_start|>system
406
+ Answer the questions.""",
407
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
408
+ version="mpt",
409
+ messages=(),
410
+ offset=0,
411
+ sep_style=SeparatorStyle.MPT,
412
+ sep="<|im_end|>",
413
+ )
414
+
415
+ conv_llama3 = Conversation(
416
+ system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""",
417
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
418
+ version="llama3",
419
+ messages=(),
420
+ offset=0,
421
+ sep_style=SeparatorStyle.MPT,
422
+ sep="<|eot_id|>",
423
+ )
424
+
425
+ conv_llama3_demo = Conversation(
426
+ system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. Your name is emova, and you are purely developed by the emova Team.""",
427
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
428
+ version="llama3_demo",
429
+ messages=(),
430
+ offset=0,
431
+ sep_style=SeparatorStyle.MPT,
432
+ sep="<|eot_id|>",
433
+ )
434
+
435
+ conv_llama3_without_system = Conversation(
436
+ system="",
437
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
438
+ version="llama3_without_system",
439
+ messages=(),
440
+ offset=0,
441
+ sep_style=SeparatorStyle.MPT,
442
+ sep="<|eot_id|>",
443
+ )
444
+
445
+ conv_llama3_without_systemV2 = Conversation(
446
+ system="",
447
+ roles=("user:", "assistant:"),
448
+ version="llama3_without_systemv2",
449
+ messages=(),
450
+ offset=0,
451
+ sep_style=SeparatorStyle.MPT,
452
+ sep="\n\n",
453
+ )
454
+
455
+ conv_qwen2 = Conversation(
456
+ system='<|im_start|>system\nYou are a helpful assistant.',
457
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
458
+ version="qwen2",
459
+ messages=(),
460
+ offset=0,
461
+ sep_style=SeparatorStyle.MPT,
462
+ sep="<|im_end|>\n",
463
+ )
464
+
465
+ conv_glm4 = Conversation(
466
+ system='[gMASK]<sop>',
467
+ roles=("<|user|>", "<|assistant|>"),
468
+ version="glm4",
469
+ messages=(),
470
+ offset=0,
471
+ sep_style=SeparatorStyle.GLM4,
472
+ sep="\n",
473
+ )
474
+
475
+
476
+ default_conversation = conv_vicuna_v1
477
+ conv_templates = {
478
+ "default": conv_vicuna_v0,
479
+ "v0": conv_vicuna_v0,
480
+ "v1": conv_vicuna_v1,
481
+ "vicuna_v1": conv_vicuna_v1,
482
+ "llama_2": conv_llama_2,
483
+ "mistral_instruct": conv_mistral_instruct,
484
+ "chatml_direct": conv_chatml_direct,
485
+ "mistral_direct": conv_chatml_direct,
486
+
487
+ "plain": conv_llava_plain,
488
+ "v0_plain": conv_llava_plain,
489
+ "llava_v0": conv_llava_v0,
490
+ "v0_mmtag": conv_llava_v0_mmtag,
491
+ "llava_v1": conv_llava_v1,
492
+ "v1_mmtag": conv_llava_v1_mmtag,
493
+ "llava_llama_2": conv_llava_llama_2,
494
+ "llama3": conv_llama3,
495
+ "llama3_demo": conv_llama3_demo,
496
+ "llama3_without_system": conv_llama3_without_system,
497
+ "conv_llama3_without_systemV2": conv_llama3_without_systemV2,
498
+
499
+ "mpt": conv_mpt,
500
+ "qwen2": conv_qwen2,
501
+ "glm4": conv_glm4,
502
+ }
503
+
504
+
505
+ if __name__ == "__main__":
506
+ print(default_conversation.get_prompt())
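As a quick illustration of how the templates above differ, the following is a small, model-free sketch (run from the repo root, with an illustrative question) that prints the prompt string each style assembles:

    # Compare the prompt strings assembled by two of the templates defined above.
    from conversation_public import conv_templates

    for name in ["llama3_demo", "vicuna_v1"]:
        conv = conv_templates[name].copy()
        conv.append_message(conv.roles[0], "What can you do?")  # user turn (illustrative)
        conv.append_message(conv.roles[1], None)                # empty assistant turn for the model to complete
        print(f"--- {name} ---")
        print(repr(conv.get_prompt()))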
requirements.txt ADDED
@@ -0,0 +1,29 @@
1
+ omegaconf
2
+ torch==2.0.1
3
+ torchvision==0.15.2
4
+ transformers==4.44.0
5
+ sentencepiece==0.1.99
6
+ accelerate==0.33.0
7
+ einops==0.6.1
8
+ einops-exts==0.0.4
9
+ timm==0.6.13
10
+ scipy
11
+ gradio
12
+
13
+ monotonic_align
14
+ librosa==0.8.0
15
+ phonemizer
16
+ unidecode
17
+ hydra-core==1.3.2
18
+ pytorch_lightning==1.1.0
19
+ wget
20
+ wrapt
21
+ onnx
22
+ frozendict
23
+ inflect
24
+ braceexpand
25
+ webdataset
26
+ torch_stft
27
+ sox
28
+ editdistance
29
+ numpy==1.23.5
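After installing these pins with pip install -r requirements.txt, a quick sanity check of the key versions can look like the sketch below (the package list shown is illustrative):

    # Confirm that the pinned core packages resolved as expected.
    import importlib.metadata as md

    for pkg, want in [("torch", "2.0.1"), ("transformers", "4.44.0"),
                      ("accelerate", "0.33.0"), ("numpy", "1.23.5")]:
        print(f"{pkg}: installed {md.version(pkg)}, pinned {want}")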