|
import gradio as gr |
|
from PIL import Image |
|
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
import soundfile as sf |
|
import torch |
|
|
|
|
|
# Text-generation backend: tokenizer and language-model head are loaded from
# the same "gpt2" checkpoint so their vocabularies match.
_GPT2_CHECKPOINT = "gpt2"
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(_GPT2_CHECKPOINT)
gpt2_model = GPT2LMHeadModel.from_pretrained(_GPT2_CHECKPOINT)
|
|
|
|
|
# Speech-recognition backend: the processor (feature extraction + CTC
# decoding) and the acoustic model come from the same Wav2Vec2 checkpoint.
_WAV2VEC2_CHECKPOINT = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(_WAV2VEC2_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(_WAV2VEC2_CHECKPOINT)
|
|
|
|
|
def handle_text(text):
    """Generate a GPT-2 continuation for *text* and return it as a string.

    The prompt is terminated with the tokenizer's EOS token before being
    encoded, and only the newly generated tokens (everything after the
    prompt) are decoded and returned.
    """
    prompt_ids = gpt2_tokenizer.encode(
        text + gpt2_tokenizer.eos_token,
        return_tensors='pt',
    )
    prompt_length = prompt_ids.shape[-1]

    # EOS doubles as the padding token id for generation.
    generated_ids = gpt2_model.generate(
        prompt_ids,
        max_length=1000,
        pad_token_id=gpt2_tokenizer.eos_token_id,
    )

    # Drop the echoed prompt; keep only the continuation of the first sequence.
    continuation_ids = generated_ids[:, prompt_length:][0]
    return gpt2_tokenizer.decode(continuation_ids, skip_special_tokens=True)
|
|
|
|
|
def handle_image(img):
    """Return a canned remark for an uploaded image.

    The image itself is not inspected: the same fixed reply is produced for
    any value of *img*.
    """
    canned_reply = "This image seems nice!"
    return canned_reply
|
|
|
|
|
def handle_audio(audio):
    """Transcribe an audio clip with Wav2Vec2 and reply to the transcription.

    Args:
        audio: Either a path / file-like object readable by ``soundfile``,
            or a ``(sample_rate, samples)`` pair such as the one Gradio's
            Audio input produces in its "numpy" mode.

    Returns:
        The text produced by ``handle_text`` for the transcription.
    """
    # Accept both calling conventions instead of assuming a file path.
    if isinstance(audio, (tuple, list)):
        sample_rate, speech = audio
    else:
        speech, sample_rate = sf.read(audio)

    # NOTE(review): integer PCM samples are only cast to float here, not
    # rescaled; this relies on the processor normalising the amplitude.
    waveform = torch.as_tensor(speech, dtype=torch.float32)

    # Multi-channel audio is laid out as (frames, channels); the model takes
    # a single 1-D waveform, so average the channels down to mono.
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=-1)

    # The acoustic model expects the feature extractor's sampling rate; feeding
    # any other rate silently degrades the transcription, so resample first.
    target_rate = processor.feature_extractor.sampling_rate
    if sample_rate != target_rate and waveform.numel() > 0:
        resampled_length = max(1, round(waveform.shape[0] * target_rate / sample_rate))
        # Plain linear interpolation: no anti-aliasing filter is applied.
        waveform = torch.nn.functional.interpolate(
            waveform[None, None, :],
            size=resampled_length,
            mode="linear",
            align_corners=False,
        )[0, 0]

    input_values = processor(
        waveform.numpy(),
        sampling_rate=target_rate,
        return_tensors="pt",
    ).input_values

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: take the most likely token at every frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    return handle_text(transcription)
|
|
|
def _provided(value):
    """Return True unless *value* is None or an empty string."""
    return value is not None and not (isinstance(value, str) and not value)


def chatbot(inputs, img=None, audio=None):
    """Route every supplied modality to its handler and merge the replies.

    Two calling conventions are supported:

    * ``chatbot(text, img, audio)`` - one positional argument per input
      component, which is how ``gr.Interface`` invokes its ``fn``.
    * ``chatbot((text, img, audio))`` - a single 3-item tuple or list.

    Args:
        inputs: The text input, or a ``(text, img, audio)`` sequence.
        img: The image input, or None when no image was given.
        audio: The audio input, or None when no audio was given.

    Returns:
        The non-empty handler replies joined by newlines, in the order
        text, image, audio. An empty string when nothing was supplied.
    """
    packed = (
        isinstance(inputs, (tuple, list))
        and len(inputs) == 3
        and img is None
        and audio is None
    )
    if packed:
        text, img, audio = inputs
    else:
        text = inputs

    # Presence is tested with `is not None` rather than truthiness because an
    # image delivered as a NumPy array cannot be evaluated as a boolean.
    replies = []
    if _provided(text):
        replies.append(handle_text(text))
    if _provided(img):
        replies.append(handle_image(img))
    if _provided(audio):
        replies.append(handle_audio(audio))

    return "\n".join(reply for reply in replies if reply)
|
|
|
|
|
# Input components: free text, an image upload, and an audio clip.
_text_input = gr.inputs.Textbox(lines=2, placeholder="Input Text here...")
_image_input = gr.inputs.Image(label="Upload Image")
_audio_input = gr.inputs.Audio(label="Input Audio")

# Single text area that shows the combined reply.
_reply_output = gr.outputs.Textbox(label="Output")

iface = gr.Interface(
    fn=chatbot,
    inputs=[_text_input, _image_input, _audio_input],
    outputs=_reply_output,
    title="Multimodal Chatbot",
    description="This chatbot can handle text, image, and audio inputs. Try it out!",
)

# Start the web UI as soon as the module is executed.
iface.launch()
|
|