Update app.py
app.py CHANGED
@@ -1,25 +1,31 @@
 import gradio as gr
 from PIL import Image
-import
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
+import soundfile as sf
+import torch
 
 # Load pre-trained model and tokenizer
-
-
+gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
+
+# Load pre-trained model and processor
+processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
 
 # Placeholder function to handle text input
 def handle_text(text):
     # encode the new user input, add the eos_token and return a tensor in Pytorch
-    new_user_input_ids =
+    new_user_input_ids = gpt2_tokenizer.encode(text + gpt2_tokenizer.eos_token, return_tensors='pt')
 
     # append the new user input tokens to the chat history
     bot_input_ids = new_user_input_ids
 
     # generate a response
-    chat_history_ids =
+    chat_history_ids = gpt2_model.generate(bot_input_ids, max_length=1000, pad_token_id=gpt2_tokenizer.eos_token_id)
 
     # Print the generated chat
-    chat_output =
+    chat_output = gpt2_tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
     return chat_output
 
 # Placeholder function to handle image input
@@ -29,12 +35,16 @@ def handle_image(img):
 
 # Placeholder function to handle audio input
 def handle_audio(audio):
-    #
-
-
-
-
-
+    # load audio
+    speech, _ = sf.read(audio)
+    # transcribe speech to text
+    input_values = processor(speech, return_tensors="pt").input_values
+    # perform forward pass
+    logits = model(input_values).logits
+    # take argmax and decode
+    predicted_ids = torch.argmax(logits, dim=-1)
+    transcriptions = processor.decode(predicted_ids[0])
+    return handle_text(transcriptions)
 
 def chatbot(inputs):
     text, img, audio = inputs
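Note on the new text path: since bot_input_ids is just the encoded prompt, the slice chat_history_ids[:, bot_input_ids.shape[-1]:] decodes only the tokens generated after the prompt, so the reply does not echo the user's message. The encode/generate/slice pattern mirrors the Hugging Face DialoGPT chat example; base gpt2 is not dialogue-tuned, so replies are free-form continuations. A minimal smoke test (illustrative, not part of the commit):

    # assumes the model/tokenizer loads from the diff above have run
    reply = handle_text("Hello, how are you?")
    print(reply)  # prints only the generated continuation; the prompt is sliced off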
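Note on the new audio path: facebook/wav2vec2-base-960h was trained on 16 kHz single-channel audio, while sf.read returns whatever rate and channel count the file happens to have, so transcription quality depends on the input already matching that format. A defensive variant of the load step, assuming librosa is available for resampling (it is not a dependency of this commit):

    import librosa  # assumed dependency, not in the commit

    speech, sample_rate = sf.read(audio)
    if speech.ndim > 1:
        speech = speech.mean(axis=1)  # down-mix stereo to mono
    if sample_rate != 16000:
        speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
    input_values = processor(speech, sampling_rate=16000, return_tensors="pt").input_values

Passing sampling_rate explicitly also silences the warning recent transformers versions emit when the rate is left implicit, and wrapping the forward pass in torch.no_grad() would avoid building a gradient graph during inference.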
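The commit stops short of the interface wiring, and chatbot unpacks a single tuple whereas gr.Interface passes each input component's value as a separate argument, so a small adapter is needed. A minimal sketch of how the handlers could be exposed, assuming a Gradio 3+ style API (component choices and labels here are illustrative, not from the commit):

    demo = gr.Interface(
        fn=lambda text, img, audio: chatbot((text, img, audio)),  # adapt to chatbot's tuple signature
        inputs=[
            gr.Textbox(label="Message"),
            gr.Image(type="pil", label="Image"),
            gr.Audio(type="filepath", label="Audio"),
        ],
        outputs=gr.Textbox(label="Reply"),
    )

    if __name__ == "__main__":
        demo.launch()

type="filepath" matches handle_audio, which passes its argument straight to sf.read, and type="pil" matches the PIL import at the top of the file.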