lmzjms committed
Commit 85c78a5 · 1 Parent(s): ce764e5

Update app.py

Files changed (1)
  1. app.py +176 -213
app.py CHANGED
@@ -1,146 +1,3 @@
1
- # import torch
2
- # import numpy as np
3
- # import gradio as gr
4
- # from PIL import Image
5
- # from omegaconf import OmegaConf
6
- # from pathlib import Path
7
- # from vocoder.bigvgan.models import VocoderBigVGAN
8
- # from ldm.models.diffusion.ddim import DDIMSampler
9
- # from ldm.util import instantiate_from_config
10
- # from wav_evaluation.models.CLAPWrapper import CLAPWrapper
11
-
12
- # SAMPLE_RATE = 16000
13
-
14
- # torch.set_grad_enabled(False)
15
- # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
16
-
17
-
18
- # def initialize_model(config, ckpt):
19
- # config = OmegaConf.load(config)
20
- # model = instantiate_from_config(config.model)
21
- # model.load_state_dict(torch.load(ckpt,map_location='cpu')["state_dict"], strict=False)
22
-
23
- # model = model.to(device)
24
- # model.cond_stage_model.to(model.device)
25
- # model.cond_stage_model.device = model.device
26
- # print(model.device,device,model.cond_stage_model.device)
27
- # sampler = DDIMSampler(model)
28
-
29
- # return sampler
30
-
31
- # sampler = initialize_model('configs/text_to_audio/txt2audio_args.yaml', 'useful_ckpts/ta40multi_epoch=000085.ckpt')
32
- # vocoder = VocoderBigVGAN('vocoder/logs/bigv16k53w',device=device)
33
- # clap_model = CLAPWrapper('useful_ckpts/CLAP/CLAP_weights_2022.pth','useful_ckpts/CLAP/config.yml',use_cuda=torch.cuda.is_available())
34
-
35
- # def select_best_audio(prompt,wav_list):
36
- # text_embeddings = clap_model.get_text_embeddings([prompt])
37
- # score_list = []
38
- # for data in wav_list:
39
- # sr,wav = data
40
- # audio_embeddings = clap_model.get_audio_embeddings([(torch.FloatTensor(wav),sr)], resample=True)
41
- # score = clap_model.compute_similarity(audio_embeddings, text_embeddings,use_logit_scale=False).squeeze().cpu().numpy()
42
- # score_list.append(score)
43
- # max_index = np.array(score_list).argmax()
44
- # print(score_list,max_index)
45
- # return wav_list[max_index]
46
-
47
- # def txt2audio(sampler,vocoder,prompt, seed, scale, ddim_steps, n_samples=1, W=624, H=80):
48
- # prng = np.random.RandomState(seed)
49
- # start_code = prng.randn(n_samples, sampler.model.first_stage_model.embed_dim, H // 8, W // 8)
50
- # start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
51
-
52
- # uc = None
53
- # if scale != 1.0:
54
- # uc = sampler.model.get_learned_conditioning(n_samples * [""])
55
- # c = sampler.model.get_learned_conditioning(n_samples * [prompt])# shape:[1,77,1280], i.e. this is not yet a sentence embedding, but still per-word embeddings
56
- # shape = [sampler.model.first_stage_model.embed_dim, H//8, W//8] # (z_dim, 80//2^x, 848//2^x)
57
- # samples_ddim, _ = sampler.sample(S=ddim_steps,
58
- # conditioning=c,
59
- # batch_size=n_samples,
60
- # shape=shape,
61
- # verbose=False,
62
- # unconditional_guidance_scale=scale,
63
- # unconditional_conditioning=uc,
64
- # x_T=start_code)
65
-
66
- # x_samples_ddim = sampler.model.decode_first_stage(samples_ddim)
67
- # x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0) # [0, 1]
68
-
69
- # wav_list = []
70
- # for idx,spec in enumerate(x_samples_ddim):
71
- # wav = vocoder.vocode(spec)
72
- # wav_list.append((SAMPLE_RATE,wav))
73
- # best_wav = select_best_audio(prompt,wav_list)
74
- # return best_wav
75
-
76
-
77
- # def predict(prompt, ddim_steps, num_samples, scale, seed):# experiments show the input_image must be 256x256 or 512x512 to work properly; it should really be resized and the output resized back, but they use padding instead, for unclear reasons
78
- # melbins,mel_len = 80,624
79
- # with torch.no_grad():
80
- # result = txt2audio(
81
- # sampler=sampler,
82
- # vocoder=vocoder,
83
- # prompt=prompt,
84
- # seed=seed,
85
- # scale=scale,
86
- # ddim_steps=ddim_steps,
87
- # n_samples=num_samples,
88
- # H=melbins, W=mel_len
89
- # )
90
-
91
- # return result
92
-
93
-
94
- # with gr.Blocks() as demo:
95
- # with gr.Row():
96
- # gr.Markdown("## Make-An-Audio: Text-to-Audio Generation")
97
-
98
- # with gr.Row():
99
- # with gr.Column():
100
- # prompt = gr.Textbox(label="Prompt: Input your text here. ")
101
- # run_button = gr.Button(label="Run")
102
-
103
-
104
- # with gr.Accordion("Advanced options", open=False):
105
- # num_samples = gr.Slider(
106
- # label="Select from audios num.This number control the number of candidates \
107
- # (e.g., generate three audios and choose the best to show you). A Larger value usually lead to \
108
- # better quality with heavier computation", minimum=1, maximum=10, value=3, step=1)
109
- # # num_samples = 1
110
- # ddim_steps = gr.Slider(label="Steps", minimum=1,
111
- # maximum=150, value=100, step=1)
112
- # scale = gr.Slider(
113
- # label="Guidance Scale:(Large => more relevant to text but the quality may drop)", minimum=0.1, maximum=4.0, value=1.5, step=0.1
114
- # )
115
- # seed = gr.Slider(
116
- # label="Seed:Change this value (any integer number) will lead to a different generation result.",
117
- # minimum=0,
118
- # maximum=2147483647,
119
- # step=1,
120
- # value=44,
121
- # )
122
-
123
- # with gr.Column():
124
- # # audio_list = []
125
- # # for i in range(int(num_samples)):
126
- # # audio_list.append(gr.outputs.Audio())
127
- # outaudio = gr.Audio()
128
-
129
-
130
- # run_button.click(fn=predict, inputs=[
131
- # prompt,ddim_steps, num_samples, scale, seed], outputs=[outaudio])# the inputs arguments can only be gr.xxx components
132
- # with gr.Row():
133
- # with gr.Column():
134
- # gr.Examples(
135
- # examples = [['a dog barking and a bird chirping',100,3,1.5,55],['fireworks pop and explode',100,3,1.5,55],
136
- # ['piano and violin plays',100,3,1.5,55],['wind thunder and rain falling',100,3,1.5,55],['music made by drum kit',100,3,1.5,55]],
137
- # inputs = [prompt,ddim_steps, num_samples, scale, seed],
138
- # outputs = [outaudio]
139
- # )
140
- # with gr.Column():
141
- # pass
142
-
143
- # demo.launch()
144
  from langchain.agents.initialize import initialize_agent
145
  from langchain.agents.tools import Tool
146
  from langchain.chains.conversation.memory import ConversationBufferMemory
@@ -149,7 +6,7 @@ from audio_foundation_models import *
149
  import gradio as gr
150
 
151
  _DESCRIPTION = '# [AudioGPT](https://github.com/AIGC-Audio/AudioGPT)'
152
- _DESCRIPTION += '\n<p>This is a demo to the work <a href="https://github.com/AIGC-Audio/AudioGPT" style="text-decoration: underline;" target="_blank">AudioGPT: Understanding and Generating Speech, Music, Sound, and Talking Head</a>. </p>'
153
  _DESCRIPTION += '\n<p>This model can only be used for non-commercial purposes.'
154
  if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
155
  _DESCRIPTION += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
@@ -186,19 +43,23 @@ Previous conversation history:
186
  New input: {input}
187
  Thought: Do I need to use a tool? {agent_scratchpad}"""
188
 
189
- def cut_dialogue_history(history_memory, keep_last_n_words = 500):
190
  tokens = history_memory.split()
191
  n_tokens = len(tokens)
192
  print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
193
  if n_tokens < keep_last_n_words:
194
  return history_memory
195
- else:
196
- paragraphs = history_memory.split('\n')
197
- last_n_tokens = n_tokens
198
- while last_n_tokens >= keep_last_n_words:
199
- last_n_tokens = last_n_tokens - len(paragraphs[0].split(' '))
200
- paragraphs = paragraphs[1:]
201
- return '\n' + '\n'.join(paragraphs)
202
 
203
  class ConversationBot:
204
  def __init__(self, load_dict):
@@ -208,11 +69,6 @@ class ConversationBot:
208
  self.models = dict()
209
  for class_name, device in load_dict.items():
210
  self.models[class_name] = globals()[class_name](device=device)
211
- for class_name, instance in self.models.items():
212
- for e in dir(instance):
213
- if e.startswith('inference'):
214
- func = getattr(instance, e)
215
- self.tools.append(Tool(name=func.name, description=func.description, func=func))
216
 
217
  def run_text(self, text, state):
218
  print("===============Running run_text =============")
@@ -225,7 +81,7 @@ class ConversationBot:
225
  response = res['output']
226
  state = state + [(text, response)]
227
  print("Outputs:", state)
228
- return state, state, gr.Audio.update(visible=False), gr.Image.update(visible=False), gr.Button.update(visible=False)
229
  else:
230
  tool = res['intermediate_steps'][0][0].tool
231
  if tool == "Generate Image From User Input Text":
@@ -234,14 +90,14 @@ class ConversationBot:
234
  state = state + [(text, response)]
235
  print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
236
  f"Current Memory: {self.agent.memory.buffer}")
237
- return state, state, gr.Audio.update(visible=False), gr.Image.update(visible=False), gr.Button.update(visible=False)
238
  elif tool == "Detect The Sound Event From The Audio":
239
  image_filename = res['intermediate_steps'][0][1]
240
  response = res['output'] + f"![](/file={image_filename})*{image_filename}*"
241
  state = state + [(text, response)]
242
  print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
243
  f"Current Memory: {self.agent.memory.buffer}")
244
- return state, state, gr.Audio.update(visible=False), gr.Image.update(visible=False), gr.Button.update(visible=False)
245
  elif tool == "Generate Text From The Audio" or tool == "Transcribe speech" or tool == "Target Sound Detection":
246
  print("======>Current memory:\n %s" % self.agent.memory)
247
  response = re.sub('(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
@@ -249,22 +105,21 @@ class ConversationBot:
249
  #response = res['output'] + f"![](/file={image_filename})*{image_filename}*"
250
  state = state + [(text, response)]
251
  print("Outputs:", state)
252
- return state, state, gr.Audio.update(visible=False), gr.Image.update(visible=False), gr.Button.update(visible=False)
253
  elif tool == "Audio Inpainting":
254
  audio_filename = res['intermediate_steps'][0][0].tool_input
255
  image_filename = res['intermediate_steps'][0][1]
256
  print("======>Current memory:\n %s" % self.agent.memory)
257
- print(res)
258
  response = res['output']
259
  state = state + [(text, response)]
260
  print("Outputs:", state)
261
- return state, state, gr.Audio.update(value=audio_filename,visible=True), gr.Image.update(value=image_filename,visible=True), gr.Button.update(visible=True)
262
  print("======>Current memory:\n %s" % self.agent.memory)
263
  response = re.sub('(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
264
  audio_filename = res['intermediate_steps'][0][1]
265
  state = state + [(text, response)]
266
  print("Outputs:", state)
267
- return state, state, gr.Audio.update(value=audio_filename,visible=True), gr.Image.update(visible=False), gr.Button.update(visible=False)
268
 
269
  def run_image_or_audio(self, file, state, txt):
270
  file_type = file.name[-3:]
@@ -273,8 +128,9 @@ class ConversationBot:
273
  print("Inputs:", file, state)
274
  print("======>Previous memory:\n %s" % self.agent.memory)
275
  audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
276
- audio_load = whisper.load_audio(file.name)
277
- soundfile.write(audio_filename, audio_load, samplerate = 16000)
278
  description = self.models['A2T'].inference(audio_filename)
279
  Human_prompt = "\nHuman: provide an audio named {}. The description is: {}. This information helps you to understand this audio, but you should use tools to finish following tasks, " \
280
  "rather than directly imagine from my description. If you understand, say \"Received\". \n".format(audio_filename, description)
@@ -286,7 +142,7 @@ class ConversationBot:
286
  #state = state + [(f"<audio src=audio_filename controls=controls></audio>*{audio_filename}*", AI_prompt)]
287
  state = state + [(f"*{audio_filename}*", AI_prompt)]
288
  print("Outputs:", state)
289
- return state, state, txt + ' ' + audio_filename + ' ', gr.Audio.update(value=audio_filename,visible=True)
290
  else:
291
  # print("===============Running run_image =============")
292
  # print("Inputs:", file, state)
@@ -312,13 +168,69 @@ class ConversationBot:
312
  state = state + [(f"![](/file={image_filename})*{image_filename}*", AI_prompt)]
313
  print(f"\nProcessed run_image, Input image: {image_filename}\nCurrent state: {state}\n"
314
  f"Current Memory: {self.agent.memory.buffer}")
315
- return state, state, txt + f'{txt} {image_filename} ', gr.Audio.update(visible=False)
316
 
317
  def inpainting(self, state, audio_filename, image_filename):
318
  print("===============Running inpainting =============")
319
  print("Inputs:", state)
320
  print("======>Previous memory:\n %s" % self.agent.memory)
321
- # inpaint = Inpaint(device="cpu")
322
  new_image_filename, new_audio_filename = self.models['Inpaint'].predict(audio_filename, image_filename)
323
  AI_prompt = "Here are the predict audio and the mel spectrum." + f"*{new_audio_filename}*" + f"![](/file={new_image_filename})*{new_image_filename}*"
324
  self.agent.memory.buffer = self.agent.memory.buffer + 'AI: ' + AI_prompt
@@ -328,33 +240,62 @@ class ConversationBot:
328
  return state, state, gr.Image.update(visible=False), gr.Audio.update(value=new_audio_filename, visible=True), gr.Button.update(visible=False)
329
  def clear_audio(self):
330
  return gr.Audio.update(value=None, visible=False)
331
  def clear_image(self):
332
  return gr.Image.update(value=None, visible=False)
333
  def clear_button(self):
334
  return gr.Button.update(visible=False)
335
- def init_agent(self, openai_api_key):
336
- self.llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
337
- self.agent = initialize_agent(
338
- self.tools,
339
- self.llm,
340
- agent="conversational-react-description",
341
- verbose=True,
342
- memory=self.memory,
343
- return_intermediate_steps=True,
344
- agent_kwargs={'prefix': AUDIO_CHATGPT_PREFIX, 'format_instructions': AUDIO_CHATGPT_FORMAT_INSTRUCTIONS, 'suffix': AUDIO_CHATGPT_SUFFIX}, )
345
- return gr.update(visible = True)
346
 
347
 
348
 
349
  if __name__ == '__main__':
350
  bot = ConversationBot({'ImageCaptioning': 'cuda:0',
351
- # 'T2A': 'cuda:0',
352
- # 'I2A': 'cuda:0',
353
  'TTS': 'cpu',
354
  'T2S': 'cpu',
355
  'ASR': 'cuda:0',
356
  'A2T': 'cpu',
357
- # 'Inpaint': 'cuda:0',
358
  'SoundDetection': 'cpu',
359
  'Binaural': 'cuda:0',
360
  'SoundExtraction': 'cuda:0',
@@ -362,37 +303,50 @@ if __name__ == '__main__':
362
  'Speech_Enh_SC': 'cuda:0',
363
  'Speech_SS': 'cuda:0'
364
  })
365
- with gr.Blocks(css="#chatbot {overflow:auto; height:500px;}") as demo:
366
- gr.Markdown(_DESCRIPTION)
367
-
368
  with gr.Row():
369
  openai_api_key_textbox = gr.Textbox(
370
- placeholder="Paste your OpenAI API key here to start AudioGPT(sk-...) and press Enter ↵️",
371
  show_label=False,
372
  lines=1,
373
  type="password",
374
  )
375
-
376
- chatbot = gr.Chatbot(elem_id="chatbot", label="AudioGPT")
377
- state = gr.State([])
378
- with gr.Row(visible = False) as input_raws:
379
  with gr.Column(scale=0.7):
380
  txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(container=False)
381
  with gr.Column(scale=0.1, min_width=0):
382
  run = gr.Button("🏃‍♂️Run")
383
  with gr.Column(scale=0.1, min_width=0):
384
- clear = gr.Button("🔄Clear️")
385
  with gr.Column(scale=0.1, min_width=0):
386
  btn = gr.UploadButton("🖼️/🎙️ Upload", file_types=["image","audio"])
387
- with gr.Row():
388
- with gr.Column():
389
- outaudio = gr.Audio(visible=False)
390
- with gr.Row():
391
- with gr.Column():
392
- show_mel = gr.Image(type="filepath",tool='sketch',visible=False)
393
- with gr.Row():
394
- with gr.Column():
395
- run_button = gr.Button("Predict Masked Place",visible=False)
396
  gr.Examples(
397
  examples=["Generate a speech with text 'here we go'",
398
  "Transcribe this speech",
@@ -409,18 +363,27 @@ if __name__ == '__main__':
409
  inputs=txt
410
  )
411
 
412
- openai_api_key_textbox.submit(bot.init_agent, [openai_api_key_textbox], [input_raws])
413
- txt.submit(bot.run_text, [txt, state], [chatbot, state, outaudio, show_mel, run_button])
414
  txt.submit(lambda: "", None, txt)
415
- run.click(bot.run_text, [txt, state], [chatbot, state, outaudio, show_mel, run_button])
416
  run.click(lambda: "", None, txt)
417
- btn.upload(bot.run_image_or_audio, [btn, state, txt], [chatbot, state, txt, outaudio])
418
- run_button.click(bot.inpainting, [state, outaudio, show_mel], [chatbot, state, show_mel, outaudio, run_button])
419
- clear.click(bot.memory.clear)
420
- clear.click(lambda: [], None, chatbot)
421
- clear.click(lambda: [], None, state)
422
- clear.click(lambda:None, None, txt)
423
- clear.click(bot.clear_button, None, run_button)
424
- clear.click(bot.clear_image, None, show_mel)
425
- clear.click(bot.clear_audio, None, outaudio)
426
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  from langchain.agents.initialize import initialize_agent
2
  from langchain.agents.tools import Tool
3
  from langchain.chains.conversation.memory import ConversationBufferMemory
 
6
  import gradio as gr
7
 
8
  _DESCRIPTION = '# [AudioGPT](https://github.com/AIGC-Audio/AudioGPT)'
9
+ _DESCRIPTION += '\n<p>This is a demo to the work <a href="https://github.com/AIGC-Audio/AudioGPT" style="text-decoration: underline;" target="_blank">AudioGPT: Sending and Receiving Speech, Sing, Audio, and Talking head during chatting</a>. </p>'
10
  _DESCRIPTION += '\n<p>This model can only be used for non-commercial purposes.'
11
  if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
12
  _DESCRIPTION += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
 
43
  New input: {input}
44
  Thought: Do I need to use a tool? {agent_scratchpad}"""
45
 
46
+
47
+ def cut_dialogue_history(history_memory, keep_last_n_words=400):
48
+ if history_memory is None or len(history_memory) == 0:
49
+ return history_memory
50
  tokens = history_memory.split()
51
  n_tokens = len(tokens)
52
  print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
53
  if n_tokens < keep_last_n_words:
54
  return history_memory
55
+ paragraphs = history_memory.split('\n')
56
+ last_n_tokens = n_tokens
57
+ while last_n_tokens >= keep_last_n_words:
58
+ last_n_tokens -= len(paragraphs[0].split(' '))
59
+ paragraphs = paragraphs[1:]
60
+ return '\n' + '\n'.join(paragraphs)
61
+
62
+
63
 
64
  class ConversationBot:
65
  def __init__(self, load_dict):
 
69
  self.models = dict()
70
  for class_name, device in load_dict.items():
71
  self.models[class_name] = globals()[class_name](device=device)
72
 
73
  def run_text(self, text, state):
74
  print("===============Running run_text =============")
 
81
  response = res['output']
82
  state = state + [(text, response)]
83
  print("Outputs:", state)
84
+ return state, state, gr.Audio.update(visible=False), gr.Video.update(visible=False), gr.Image.update(visible=False), gr.Button.update(visible=False)
85
  else:
86
  tool = res['intermediate_steps'][0][0].tool
87
  if tool == "Generate Image From User Input Text":
 
90
  state = state + [(text, response)]
91
  print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
92
  f"Current Memory: {self.agent.memory.buffer}")
93
+ return state, state, gr.Audio.update(visible=False), gr.Video.update(visible=False), gr.Image.update(visible=False), gr.Button.update(visible=False)
94
  elif tool == "Detect The Sound Event From The Audio":
95
  image_filename = res['intermediate_steps'][0][1]
96
  response = res['output'] + f"![](/file={image_filename})*{image_filename}*"
97
  state = state + [(text, response)]
98
  print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
99
  f"Current Memory: {self.agent.memory.buffer}")
100
+ return state, state, gr.Audio.update(visible=False), gr.Video.update(visible=False), gr.Image.update(visible=False), gr.Button.update(visible=False)
101
  elif tool == "Generate Text From The Audio" or tool == "Transcribe speech" or tool == "Target Sound Detection":
102
  print("======>Current memory:\n %s" % self.agent.memory)
103
  response = re.sub('(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
 
105
  #response = res['output'] + f"![](/file={image_filename})*{image_filename}*"
106
  state = state + [(text, response)]
107
  print("Outputs:", state)
108
+ return state, state, gr.Audio.update(visible=False), gr.Video.update(visible=False), gr.Image.update(visible=False), gr.Button.update(visible=False)
109
  elif tool == "Audio Inpainting":
110
  audio_filename = res['intermediate_steps'][0][0].tool_input
111
  image_filename = res['intermediate_steps'][0][1]
112
  print("======>Current memory:\n %s" % self.agent.memory)
113
  response = res['output']
114
  state = state + [(text, response)]
115
  print("Outputs:", state)
116
+ return state, state, gr.Audio.update(value=audio_filename,visible=True), gr.Video.update(visible=False), gr.Image.update(value=image_filename,visible=True), gr.Button.update(visible=True)
117
  print("======>Current memory:\n %s" % self.agent.memory)
118
  response = re.sub('(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
119
  audio_filename = res['intermediate_steps'][0][1]
120
  state = state + [(text, response)]
121
  print("Outputs:", state)
122
+ return state, state, gr.Audio.update(value=audio_filename,visible=True), gr.Video.update(visible=False), gr.Image.update(visible=False), gr.Button.update(visible=False)
123
 
124
  def run_image_or_audio(self, file, state, txt):
125
  file_type = file.name[-3:]
 
128
  print("Inputs:", file, state)
129
  print("======>Previous memory:\n %s" % self.agent.memory)
130
  audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
131
+ # audio_load = whisper.load_audio(file.name)
132
+ audio_load, sr = soundfile.read(file.name)
133
+ soundfile.write(audio_filename, audio_load, samplerate = sr)
134
  description = self.models['A2T'].inference(audio_filename)
135
  Human_prompt = "\nHuman: provide an audio named {}. The description is: {}. This information helps you to understand this audio, but you should use tools to finish following tasks, " \
136
  "rather than directly imagine from my description. If you understand, say \"Received\". \n".format(audio_filename, description)
 
142
  #state = state + [(f"<audio src=audio_filename controls=controls></audio>*{audio_filename}*", AI_prompt)]
143
  state = state + [(f"*{audio_filename}*", AI_prompt)]
144
  print("Outputs:", state)
145
+ return state, state, gr.Audio.update(value=audio_filename,visible=True), gr.Video.update(visible=False)
146
  else:
147
  # print("===============Running run_image =============")
148
  # print("Inputs:", file, state)
 
168
  state = state + [(f"![](/file={image_filename})*{image_filename}*", AI_prompt)]
169
  print(f"\nProcessed run_image, Input image: {image_filename}\nCurrent state: {state}\n"
170
  f"Current Memory: {self.agent.memory.buffer}")
171
+ return state, state, gr.Audio.update(visible=False), gr.Video.update(visible=False)
172
+
173
+ def speech(self, speech_input, state):
174
+ input_audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
175
+ text = self.models['ASR'].translate_english(speech_input)
176
+ print("Inputs:", text, state)
177
+ print("======>Previous memory:\n %s" % self.agent.memory)
178
+ self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
179
+ res = self.agent({"input": text})
180
+ if res['intermediate_steps'] == []:
181
+ print("======>Current memory:\n %s" % self.agent.memory)
182
+ response = res['output']
183
+ output_audio_filename = self.models['TTS'].inference(response)
184
+ state = state + [(text, response)]
185
+ print("Outputs:", state)
186
+ return gr.Audio.update(value=None), gr.Audio.update(value=output_audio_filename,visible=True), state, gr.Video.update(visible=False)
187
+ else:
188
+ tool = res['intermediate_steps'][0][0].tool
189
+ if tool == "Generate Image From User Input Text" or tool == "Generate Text From The Audio" or tool == "Target Sound Detection":
190
+ print("======>Current memory:\n %s" % self.agent.memory)
191
+ response = re.sub('(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
192
+ output_audio_filename = self.models['TTS'].inference(res['output'])
193
+ state = state + [(text, response)]
194
+ print("Outputs:", state)
195
+ return gr.Audio.update(value=None), gr.Audio.update(value=output_audio_filename,visible=True), state, gr.Video.update(visible=False)
196
+ elif tool == "Transcribe Speech":
197
+ print("======>Current memory:\n %s" % self.agent.memory)
198
+ output_audio_filename = self.models['TTS'].inference(res['output'])
199
+ response = res['output']
200
+ state = state + [(text, response)]
201
+ print("Outputs:", state)
202
+ return gr.Audio.update(value=None), gr.Audio.update(value=output_audio_filename,visible=True), state, gr.Video.update(visible=False)
203
+ elif tool == "Detect The Sound Event From The Audio":
204
+ print("======>Current memory:\n %s" % self.agent.memory)
205
+ image_filename = res['intermediate_steps'][0][1]
206
+ output_audio_filename = self.models['TTS'].inference(res['output'])
207
+ response = res['output'] + f"![](/file={image_filename})*{image_filename}*"
208
+ state = state + [(text, response)]
209
+ print("Outputs:", state)
210
+ return gr.Audio.update(value=None), gr.Audio.update(value=output_audio_filename,visible=True), state, gr.Video.update(visible=False)
211
+ elif tool == "Generate a talking human portrait video given a input Audio":
212
+ video_filename = res['intermediate_steps'][0][1]
213
+ print("======>Current memory:\n %s" % self.agent.memory)
214
+ response = res['output']
215
+ output_audio_filename = self.models['TTS'].inference(res['output'])
216
+ state = state + [(text, response)]
217
+ print("Outputs:", state)
218
+ return gr.Audio.update(value=None), gr.Audio.update(value=output_audio_filename,visible=True), state, gr.Video.update(value=video_filename,visible=True)
219
+ print("======>Current memory:\n %s" % self.agent.memory)
220
+ response = re.sub('(image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
221
+ audio_filename = res['intermediate_steps'][0][1]
222
+ Res = "The audio file has been generated and the audio is "
223
+ output_audio_filename = merge_audio(self.models['TTS'].inference(Res), audio_filename)
224
+ print(output_audio_filename)
225
+ state = state + [(text, response)]
226
+ response = res['output']
227
+ print("Outputs:", state)
228
+ return gr.Audio.update(value=None), gr.Audio.update(value=output_audio_filename,visible=True), state, gr.Video.update(visible=False)
229
 
230
  def inpainting(self, state, audio_filename, image_filename):
231
  print("===============Running inpainting =============")
232
  print("Inputs:", state)
233
  print("======>Previous memory:\n %s" % self.agent.memory)
234
  new_image_filename, new_audio_filename = self.models['Inpaint'].predict(audio_filename, image_filename)
235
  AI_prompt = "Here are the predict audio and the mel spectrum." + f"*{new_audio_filename}*" + f"![](/file={new_image_filename})*{new_image_filename}*"
236
  self.agent.memory.buffer = self.agent.memory.buffer + 'AI: ' + AI_prompt
 
240
  return state, state, gr.Image.update(visible=False), gr.Audio.update(value=new_audio_filename, visible=True), gr.Button.update(visible=False)
241
  def clear_audio(self):
242
  return gr.Audio.update(value=None, visible=False)
243
+ def clear_input_audio(self):
244
+ return gr.Audio.update(value=None)
245
  def clear_image(self):
246
  return gr.Image.update(value=None, visible=False)
247
+ def clear_video(self):
248
+ return gr.Video.update(value=None, visible=False)
249
  def clear_button(self):
250
  return gr.Button.update(visible=False)
251
+
252
+ def init_agent(self, openai_api_key, interaction_type):
253
+ if interaction_type == "text":
254
+ for class_name, instance in self.models.items():
255
+ for e in dir(instance):
256
+ if e.startswith('inference'):
257
+ func = getattr(instance, e)
258
+ self.tools.append(Tool(name=func.name, description=func.description, func=func))
259
+ self.llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
260
+ self.agent = initialize_agent(
261
+ self.tools,
262
+ self.llm,
263
+ agent="conversational-react-description",
264
+ verbose=True,
265
+ memory=self.memory,
266
+ return_intermediate_steps=True,
267
+ agent_kwargs={'prefix': AUDIO_CHATGPT_PREFIX, 'format_instructions': AUDIO_CHATGPT_FORMAT_INSTRUCTIONS, 'suffix': AUDIO_CHATGPT_SUFFIX}, )
268
+ return gr.update(visible = False), gr.update(visible = True), gr.update(visible = True), gr.update(visible = False)
269
+ else:
270
+ for class_name, instance in self.models.items():
271
+ if class_name != 'T2A' and class_name != 'I2A' and class_name != 'Inpaint' and class_name != 'ASR' and class_name != 'SoundDetection' and class_name != 'Speech_Enh_SC' and class_name != 'Speech_SS':
272
+ for e in dir(instance):
273
+ if e.startswith('inference'):
274
+ func = getattr(instance, e)
275
+ self.tools.append(Tool(name=func.name, description=func.description, func=func))
276
+
277
+ self.llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
278
+ self.agent = initialize_agent(
279
+ self.tools,
280
+ self.llm,
281
+ agent="conversational-react-description",
282
+ verbose=True,
283
+ memory=self.memory,
284
+ return_intermediate_steps=True,
285
+ agent_kwargs={'prefix': AUDIO_CHATGPT_PREFIX, 'format_instructions': AUDIO_CHATGPT_FORMAT_INSTRUCTIONS, 'suffix': AUDIO_CHATGPT_SUFFIX}, )
286
+ return gr.update(visible = False), gr.update(visible = False), gr.update(visible = False), gr.update(visible = True)
287
 
288
 
289
 
290
  if __name__ == '__main__':
291
  bot = ConversationBot({'ImageCaptioning': 'cuda:0',
292
+ 'T2A': 'cuda:0',
293
+ 'I2A': 'cuda:0',
294
  'TTS': 'cpu',
295
  'T2S': 'cpu',
296
  'ASR': 'cuda:0',
297
  'A2T': 'cpu',
298
+ 'Inpaint': 'cuda:0',
299
  'SoundDetection': 'cpu',
300
  'Binaural': 'cuda:0',
301
  'SoundExtraction': 'cuda:0',
 
303
  'Speech_Enh_SC': 'cuda:0',
304
  'Speech_SS': 'cuda:0'
305
  })
306
+ with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
307
  with gr.Row():
308
+ gr.Markdown("## AudioGPT")
309
+ chatbot = gr.Chatbot(elem_id="chatbot", label="AudioGPT", visible=False)
310
+ state = gr.State([])
311
+
312
+ with gr.Row() as select_raws:
313
+ with gr.Column(scale=0.7):
314
+ interaction_type = gr.Radio(choices=['text', 'speech'], value='text', label='Interaction Type')
315
  openai_api_key_textbox = gr.Textbox(
316
+ placeholder="Paste your OpenAI API key here to start AudioGPT(sk-...) and press Enter ↵️",
317
  show_label=False,
318
  lines=1,
319
  type="password",
320
  )
321
+ with gr.Row(visible=False) as text_input_raws:
322
  with gr.Column(scale=0.7):
323
  txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(container=False)
324
  with gr.Column(scale=0.1, min_width=0):
325
  run = gr.Button("🏃‍♂️Run")
326
  with gr.Column(scale=0.1, min_width=0):
327
+ clear_txt = gr.Button("🔄Clear️")
328
  with gr.Column(scale=0.1, min_width=0):
329
  btn = gr.UploadButton("🖼️/🎙️ Upload", file_types=["image","audio"])
330
+
331
+ with gr.Row():
332
+ outaudio = gr.Audio(visible=False)
333
+ with gr.Row():
334
+ with gr.Column(scale=0.3, min_width=0):
335
+ outvideo = gr.Video(visible=False)
336
+ with gr.Row():
337
+ show_mel = gr.Image(type="filepath",tool='sketch',visible=False)
338
+ with gr.Row():
339
+ run_button = gr.Button("Predict Masked Place",visible=False)
340
+
341
+ with gr.Row(visible=False) as speech_input_raws:
342
+ with gr.Column(scale=0.7):
343
+ speech_input = gr.Audio(source="microphone", type="filepath", label="Input")
344
+ with gr.Column(scale=0.15, min_width=0):
345
+ submit_btn = gr.Button("🏃‍♂️submit")
346
+ with gr.Column(scale=0.15, min_width=0):
347
+ clear_speech = gr.Button("🔄Clear️")
348
+ with gr.Row():
349
+ speech_output = gr.Audio(label="Output",visible=False)
350
  gr.Examples(
351
  examples=["Generate a speech with text 'here we go'",
352
  "Transcribe this speech",
 
363
  inputs=txt
364
  )
365
 
366
+ openai_api_key_textbox.submit(bot.init_agent, [openai_api_key_textbox, interaction_type], [select_raws, chatbot, text_input_raws, speech_input_raws])
367
+
368
+ txt.submit(bot.run_text, [txt, state], [chatbot, state, outaudio, outvideo, show_mel, run_button])
369
  txt.submit(lambda: "", None, txt)
370
+ run.click(bot.run_text, [txt, state], [chatbot, state, outaudio, outvideo, show_mel, run_button])
371
  run.click(lambda: "", None, txt)
372
+ btn.upload(bot.run_image_or_audio, [btn, state, txt], [chatbot, state, outaudio, outvideo])
373
+ run_button.click(bot.inpainting, [state, outaudio, show_mel], [chatbot, state, show_mel, outaudio, outvideo, run_button])
374
+ clear_txt.click(bot.memory.clear)
375
+ clear_txt.click(lambda: [], None, chatbot)
376
+ clear_txt.click(lambda: [], None, state)
377
+ clear_txt.click(lambda:None, None, txt)
378
+ clear_txt.click(bot.clear_button, None, run_button)
379
+ clear_txt.click(bot.clear_image, None, show_mel)
380
+ clear_txt.click(bot.clear_audio, None, outaudio)
381
+ clear_txt.click(bot.clear_video, None, outvideo)
382
+
383
+ submit_btn.click(bot.speech, [speech_input, state], [speech_input, speech_output, state, outvideo])
384
+ clear_speech.click(bot.clear_input_audio, None, speech_input)
385
+ clear_speech.click(bot.clear_audio, None, speech_output)
386
+ clear_speech.click(lambda: [], None, state)
387
+ clear_speech.click(bot.clear_video, None, outvideo)
388
+
389
  demo.launch(server_name="0.0.0.0", server_port=7860)