# wanderJoy / app.py
# Last commit 97f8dd4 by HakimHa: "Update app.py" (2.63 kB)
import gradio as gr
from PIL import Image
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import soundfile as sf
import torch
# Text generation: GPT-2 causal language model plus its tokenizer.
GPT2_CHECKPOINT = "gpt2"
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(GPT2_CHECKPOINT)
gpt2_model = GPT2LMHeadModel.from_pretrained(GPT2_CHECKPOINT)

# Speech recognition: Wav2Vec2 CTC model plus its processor (feature
# extractor + tokenizer), both loaded from the same checkpoint.
ASR_CHECKPOINT = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(ASR_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(ASR_CHECKPOINT)
# Text handler: one-shot GPT-2 reply to a single user message
def handle_text(text, max_length=1000):
    """Generate a GPT-2 continuation for a single user message.

    The message is terminated with GPT-2's end-of-sequence token and passed to
    ``gpt2_model.generate``. Only the newly generated tokens are decoded and
    returned. No conversation history is kept between calls.

    Args:
        text: The user's message.
        max_length: Upper bound on prompt + generated tokens handed to
            ``generate``. It is additionally capped at the model's context size.

    Returns:
        The decoded reply with special tokens removed (may be empty).
    """
    prompt_ids = gpt2_tokenizer.encode(text + gpt2_tokenizer.eos_token, return_tensors="pt")
    # GPT-2 cannot attend beyond its position-embedding table, so never ask for
    # more than that many tokens in total.
    max_length = min(max_length, gpt2_model.config.n_positions)
    # Keep only the most recent tokens of an overly long message. Otherwise the
    # prompt alone can reach max_length (empty reply) or exceed the context
    # window (crash). max(1, ...) avoids the `[-0:]` slice, which keeps
    # everything; slicing from the left also preserves the trailing EOS token.
    max_prompt_tokens = max(1, max_length // 2)
    prompt_ids = prompt_ids[:, -max_prompt_tokens:]
    # A single unpadded sequence: every position is a real token.
    attention_mask = torch.ones_like(prompt_ids)
    output_ids = gpt2_model.generate(
        prompt_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        pad_token_id=gpt2_tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; drop the prompt part.
    reply_ids = output_ids[0, prompt_ids.shape[-1]:]
    return gpt2_tokenizer.decode(reply_ids, skip_special_tokens=True)
# Image handler (stub): acknowledges any uploaded image with a fixed reply
def handle_image(img):
    """Return a canned reply for an uploaded image.

    Stub: ``img`` is not inspected at all. Replace this body with real image
    processing (captioning, classification, ...) when one is available.
    """
    canned_reply = "This image seems nice!"
    return canned_reply
# Audio handler: Wav2Vec2 speech-to-text, then a GPT-2 reply to the transcript.

# The facebook/wav2vec2-base-960h model card requires 16 kHz input audio.
_ASR_SAMPLE_RATE = 16000


def _load_speech(audio):
    """Return ``(mono float32 samples, sample_rate)`` for a path or ``(rate, samples)`` tuple."""
    if isinstance(audio, (tuple, list)):
        # Gradio Audio with type="numpy": (sample_rate, samples).
        sample_rate, speech = audio
    else:
        # A file path (Gradio Audio with type="filepath", or a direct call).
        speech, sample_rate = sf.read(audio, dtype="float32")
    if speech.dtype.kind == "i":
        # Integer PCM (e.g. int16 microphone data): scale to [-1, 1) before
        # any averaging turns it into float and hides the original bit depth.
        full_scale = float(2 ** (8 * speech.dtype.itemsize - 1))
        speech = speech.astype("float32") / full_scale
    if speech.ndim > 1:
        # Down-mix (frames, channels) to mono.
        speech = speech.mean(axis=1)
    return speech.astype("float32"), sample_rate


def _to_target_rate(speech, sample_rate, target_rate):
    """Linearly resample 1-D ``speech`` from ``sample_rate`` Hz to ``target_rate`` Hz."""
    if sample_rate == target_rate:
        return speech
    new_length = max(1, round(len(speech) * target_rate / sample_rate))
    # Plain linear interpolation with no anti-aliasing filter. It is crude, but
    # it avoids a new dependency and beats feeding the model the wrong rate.
    waveform = torch.tensor(speech, dtype=torch.float32)[None, None]
    resampled = torch.nn.functional.interpolate(
        waveform, size=new_length, mode="linear", align_corners=False
    )
    return resampled[0, 0].numpy()


def handle_audio(audio):
    """Transcribe an audio clip with Wav2Vec2 and answer it via ``handle_text``.

    Args:
        audio: Either a path to an audio file readable by ``soundfile``, or a
            ``(sample_rate, samples)`` tuple such as Gradio's ``type="numpy"``
            Audio output. Multi-channel audio is down-mixed to mono, and any
            sample rate is resampled to 16 kHz.

    Returns:
        GPT-2's reply to the transcription, or "" if the clip has no samples.
    """
    speech, sample_rate = _load_speech(audio)
    if speech.size == 0:
        return ""
    speech = _to_target_rate(speech, sample_rate, _ASR_SAMPLE_RATE)
    input_values = processor(
        speech, sampling_rate=_ASR_SAMPLE_RATE, return_tensors="pt"
    ).input_values
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(input_values).logits
    # Greedy decoding: most likely token id for every frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    return handle_text(transcription)
def chatbot(inputs, *more_inputs):
    """Route each provided input to its handler and join the replies.

    Gradio calls ``fn`` with one positional argument per input component, i.e.
    ``chatbot(text, img, audio)``. For backward compatibility, a single
    ``(text, img, audio)`` sequence passed as ``inputs`` is also accepted.

    Returns:
        The non-empty handler outputs (text, then image, then audio), joined
        with newlines; "" when nothing was provided.
    """
    if more_inputs:
        text, img, audio = (inputs, *more_inputs)
    else:
        text, img, audio = inputs
    # Gradio's Image component yields a numpy array by default, and an array's
    # truth value is ambiguous (ValueError). Test media inputs against None.
    text_output = handle_text(text) if text else None
    img_output = handle_image(img) if img is not None else None
    audio_output = handle_audio(audio) if audio is not None else None
    outputs = [o for o in (text_output, img_output, audio_output) if o]
    return "\n".join(outputs)
# Gradio UI: text, image and audio inputs feed `chatbot`; one text output.
# Components come from the top-level `gr` namespace. The legacy
# `gr.inputs` / `gr.outputs` modules were deprecated in Gradio 3 and removed
# in Gradio 4.
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.Textbox(lines=2, placeholder="Input Text here..."),
        gr.Image(label="Upload Image"),
        # handle_audio reads clips with soundfile, so ask Gradio for a file path
        # instead of the default (sample_rate, samples) tuple.
        gr.Audio(type="filepath", label="Input Audio"),
    ],
    outputs=gr.Textbox(label="Output"),
    title="Multimodal Chatbot",
    description="This chatbot can handle text, image, and audio inputs. Try it out!",
)
# Launch the Gradio interface
iface.launch()