HakimHa committed on
Commit
3442c7b
·
1 Parent(s): 4263327

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -24
app.py CHANGED
@@ -1,47 +1,42 @@
1
  import gradio as gr
2
  from PIL import Image
3
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
4
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
5
  import soundfile as sf
6
  import torch
7
 
8
- # Load pre-trained model and tokenizer for GPT2
9
- gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
10
- gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
 
 
 
 
 
 
 
 
11
 
12
  # Load pre-trained model and processor for Wav2Vec2
13
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
14
- model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
15
 
16
  # Function to handle text input
17
  def handle_text(text):
18
- # encode the new user input, add the eos_token and return a tensor in Pytorch
19
- new_user_input_ids = gpt2_tokenizer.encode(text + gpt2_tokenizer.eos_token, return_tensors='pt')
20
-
21
- # append the new user input tokens to the chat history
22
  bot_input_ids = new_user_input_ids
23
-
24
- # generate a response
25
- chat_history_ids = gpt2_model.generate(bot_input_ids, max_length=1000, pad_token_id=gpt2_tokenizer.eos_token_id)
26
-
27
- # Print the generated chat
28
- chat_output = gpt2_tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
29
  return chat_output
30
 
31
  # Function to handle image input
32
  def handle_image(img):
33
- # This is a placeholder function, replace with your own image processing function
34
  return "This image seems nice!"
35
 
36
  # Function to handle audio input
37
  def handle_audio(audio):
38
- # load audio
39
  speech, _ = sf.read(audio)
40
- # transcribe speech to text
41
  input_values = processor(speech, return_tensors="pt").input_values
42
- # perform forward pass
43
- logits = model(input_values).logits
44
- # take argmax and decode
45
  predicted_ids = torch.argmax(logits, dim=-1)
46
  transcriptions = processor.decode(predicted_ids[0])
47
  return handle_text(transcriptions)
@@ -54,7 +49,6 @@ def chatbot(text, img, audio):
54
  outputs = [o for o in [text_output, img_output, audio_output] if o]
55
  return "\n".join(outputs)
56
 
57
- # Define the Gradio interface
58
  iface = gr.Interface(
59
  fn=chatbot,
60
  inputs=[
@@ -67,5 +61,4 @@ iface = gr.Interface(
67
  description="This chatbot can handle text, image, and audio inputs. Try it out!",
68
  )
69
 
70
- # Launch the Gradio interface
71
  iface.launch()
 
1
  import gradio as gr
2
  from PIL import Image
3
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoModelForCausalLM, AutoTokenizer
 
4
  import soundfile as sf
5
  import torch
6
 
7
+ model_name_or_path = "bofenghuang/vigogne-falcon-7b-chat"
8
+
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="right", use_fast=False)
10
+ tokenizer.pad_token = tokenizer.eos_token
11
+
12
+ model = AutoModelForCausalLM.from_pretrained(
13
+ model_name_or_path,
14
+ torch_dtype=torch.float16,
15
+ device_map="auto",
16
+ trust_remote_code=True,
17
+ )
18
 
19
  # Load pre-trained model and processor for Wav2Vec2
20
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
21
+ wav2vec2_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
22
 
23
  # Function to handle text input
24
  def handle_text(text):
25
+ new_user_input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors='pt')
 
 
 
26
  bot_input_ids = new_user_input_ids
27
+ chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
28
+ chat_output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
 
 
 
 
29
  return chat_output
30
 
31
  # Function to handle image input
32
  def handle_image(img):
 
33
  return "This image seems nice!"
34
 
35
  # Function to handle audio input
36
  def handle_audio(audio):
 
37
  speech, _ = sf.read(audio)
 
38
  input_values = processor(speech, return_tensors="pt").input_values
39
+ logits = wav2vec2_model(input_values).logits
 
 
40
  predicted_ids = torch.argmax(logits, dim=-1)
41
  transcriptions = processor.decode(predicted_ids[0])
42
  return handle_text(transcriptions)
 
49
  outputs = [o for o in [text_output, img_output, audio_output] if o]
50
  return "\n".join(outputs)
51
 
 
52
  iface = gr.Interface(
53
  fn=chatbot,
54
  inputs=[
 
61
  description="This chatbot can handle text, image, and audio inputs. Try it out!",
62
  )
63
 
 
64
  iface.launch()