HakimHa committed
Commit 97f8dd4 · 1 Parent(s): 29ac35f

Update app.py

Files changed (1)
  1. app.py +22 -12
app.py CHANGED
@@ -1,25 +1,31 @@
 import gradio as gr
 from PIL import Image
-import speech_recognition as sr
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
+import soundfile as sf
+import torch
 
 # Load pre-trained model and tokenizer
-tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-model = GPT2LMHeadModel.from_pretrained("gpt2")
+gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
+
+# Load pre-trained model and processor
+processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
 
 # Placeholder function to handle text input
 def handle_text(text):
     # encode the new user input, add the eos_token and return a tensor in Pytorch
-    new_user_input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors='pt')
+    new_user_input_ids = gpt2_tokenizer.encode(text + gpt2_tokenizer.eos_token, return_tensors='pt')
 
     # append the new user input tokens to the chat history
     bot_input_ids = new_user_input_ids
 
     # generate a response
-    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
+    chat_history_ids = gpt2_model.generate(bot_input_ids, max_length=1000, pad_token_id=gpt2_tokenizer.eos_token_id)
 
     # Print the generated chat
-    chat_output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
+    chat_output = gpt2_tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
     return chat_output
 
 # Placeholder function to handle image input
@@ -29,12 +35,16 @@ def handle_image(img):
 
 # Placeholder function to handle audio input
 def handle_audio(audio):
-    # This is a placeholder function, replace with your own audio processing function
-    r = sr.Recognizer()
-    with sr.AudioFile(audio.name) as source:
-        audio_data = r.record(source)
-        text = r.recognize_google(audio_data)
-        return handle_text(text)
+    # load audio
+    speech, _ = sf.read(audio)
+    # transcribe speech to text
+    input_values = processor(speech, return_tensors="pt").input_values
+    # perform forward pass
+    logits = model(input_values).logits
+    # take argmax and decode
+    predicted_ids = torch.argmax(logits, dim=-1)
+    transcriptions = processor.decode(predicted_ids[0])
+    return handle_text(transcriptions)
 
 def chatbot(inputs):
     text, img, audio = inputs
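
For reference, the new Wav2Vec2 transcription path added in this commit can be exercised on its own along the lines below. This is a minimal sketch, not part of the commit: the file name `sample.wav` is a hypothetical test input, and the explicit `sampling_rate` argument and `torch.no_grad()` wrapper are additions here. Note that `facebook/wav2vec2-base-960h` expects 16 kHz mono audio, which the committed `handle_audio` assumes the incoming file already provides.

```python
import soundfile as sf
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Same checkpoint the commit loads at module level
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# "sample.wav" is a hypothetical test file (mono, 16 kHz, as the model expects)
speech, sample_rate = sf.read("sample.wav")

# Passing sampling_rate explicitly avoids the mismatch warning the
# committed code can emit when the rate is left implicit
input_values = processor(speech, sampling_rate=sample_rate,
                         return_tensors="pt").input_values

# Inference only, so no gradients are needed
with torch.no_grad():
    logits = model(input_values).logits

# Greedy CTC decoding: argmax over the vocabulary at each frame
predicted_ids = torch.argmax(logits, dim=-1)
print(processor.decode(predicted_ids[0]))
```

Compared with the removed `recognize_google` call, which ships the audio off to a remote web API, this path runs entirely locally once the checkpoint has been downloaded.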