# Zbot/app.py
# Import libraries
import tempfile

import streamlit as st
import torch
import transformers
import librosa
from PIL import Image  # the image-classification pipeline accepts PIL images
# Load models
text_model = transformers.pipeline("text-generation")  # defaults to gpt2
audio_model = transformers.Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
# Wav2Vec2Processor supersedes the deprecated Wav2Vec2Tokenizer
audio_processor = transformers.Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
image_model = transformers.pipeline("image-classification")
# "facebook/mmf-vit-base-16" does not appear to be a published checkpoint, so
# fall back to the pipeline's default video-classification model (requires a
# video backend such as decord or av)
video_model = transformers.pipeline("video-classification")
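# Note: Streamlit re-executes this whole script on every interaction, so
# these module-level loads repeat on each rerun; wrapping them in functions
# decorated with @st.cache_resource would keep a single cached copy of each
# model across reruns.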
# Define functions for processing inputs and outputs
def text_to_text(prompt):
    output = text_model(prompt, max_length=50)
    return output[0]["generated_text"]

def text_to_audio(prompt):
    # Placeholder: gr.outputs.Audio.from_str does not exist, and a
    # text-generation model cannot synthesize speech; return the generated
    # text until a TTS model is wired in
    return text_to_text(prompt)

def text_to_image(prompt):
    # Placeholder: real image output would need a text-to-image model
    return text_to_text(prompt)

def text_to_video(prompt):
    # Placeholder: real video output would need a text-to-video model
    return text_to_text(prompt)
def audio_to_text(audio_file):
    # facebook/wav2vec2-base-960h was trained on 16 kHz audio, so resample
    # on load instead of keeping librosa's 22.05 kHz default
    waveform = librosa.load(audio_file, sr=16000)[0]
    inputs = audio_processor(waveform, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = audio_model(inputs.input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return audio_processor.batch_decode(predicted_ids)[0]

def audio_to_audio(audio_file):
    # Identity: echo the uploaded audio back
    return audio_file

def audio_to_image(audio_file):
    # Placeholder: transcribe the audio; rendering the transcript as an
    # image would need a text-to-image model
    return audio_to_text(audio_file)

def audio_to_video(audio_file):
    # Placeholder: transcribe the audio; see audio_to_image
    return audio_to_text(audio_file)
def image_to_text(image_file):
    # The image-classification pipeline accepts PIL images directly, so no
    # cv2 preprocessing is needed
    image = Image.open(image_file).convert("RGB")
    output = image_model(image)
    return output[0]["label"]

def image_to_audio(image_file):
    # Placeholder: return the classification label as text; speech output
    # would need a TTS model
    return image_to_text(image_file)
def image_to_image(image_file):
    # Identity: echo the uploaded image back
    return image_file
def image_to_video(image_file):
    # Placeholder: return the classification label as text; producing an
    # actual video would need a generative model
    return image_to_text(image_file)

def video_to_text(video_file):
    # The video-classification pipeline expects a file path, so persist the
    # uploaded bytes to a temporary file before classifying
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        tmp.write(video_file.read())
        path = tmp.name
    output = video_model(path)
    return output[0]["label"]
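# Hedged placeholder handlers for the remaining video routes, which
# process_input dispatches to but which would need dedicated generative
# models to produce real audio or image output.
def video_to_audio(video_file):
    # Placeholder: return the classification label as text
    return video_to_text(video_file)

def video_to_image(video_file):
    # Placeholder: return the classification label as text
    return video_to_text(video_file)

def video_to_video(video_file):
    # Identity: echo the uploaded video back
    return video_file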
# Dispatch table mapping (input format, output format) to a handler
HANDLERS = {
    ("Text", "Text"): text_to_text,
    ("Text", "Audio"): text_to_audio,
    ("Text", "Image"): text_to_image,
    ("Text", "Video"): text_to_video,
    ("Audio", "Text"): audio_to_text,
    ("Audio", "Audio"): audio_to_audio,
    ("Audio", "Image"): audio_to_image,
    ("Audio", "Video"): audio_to_video,
    ("Image", "Text"): image_to_text,
    ("Image", "Audio"): image_to_audio,
    ("Image", "Image"): image_to_image,
    ("Image", "Video"): image_to_video,
    ("Video", "Text"): video_to_text,
    ("Video", "Audio"): video_to_audio,
    ("Video", "Image"): video_to_image,
    ("Video", "Video"): video_to_video,
}

# Define the process_input function
def process_input(user_input, input_format, output_format):
    # Look up the handler for the selected format pair and run it
    handler = HANDLERS.get((input_format, output_format))
    if handler is None:
        return "Invalid input or output format"
    return handler(user_input)
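# Example: process_input("Hello", "Text", "Text") runs the default
# text-generation model and returns a completion capped at 50 tokens.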
# Create a title for the app
st.title("My Generic AI App")
# Create a sidebar for selecting the input and output formats
st.sidebar.header("Select the input and output formats")
input_format = st.sidebar.selectbox("Input format", ["Text", "Audio", "Image", "Video"])
output_format = st.sidebar.selectbox("Output format", ["Text", "Audio", "Image", "Video"])
# Create a container for the input and output widgets
io_container = st.container()
# Create a chat input widget for text input
if input_format == "Text":
    user_input = st.chat_input("Type a message")
# Create a file uploader widget for audio input
elif input_format == "Audio":
    user_input = st.file_uploader("Upload an audio file", type=["wav", "mp3", "ogg"])
# Create a file uploader widget for image input
elif input_format == "Image":
    user_input = st.file_uploader("Upload an image file", type=["jpg", "png", "gif"])
# Create a file uploader widget for video input
else:
    user_input = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
# Check that the user provided some input
if user_input:
    # Display the user input with a widget that matches its modality
    with io_container:
        if input_format == "Text":
            st.chat_message("user").write(user_input)
        elif input_format == "Audio":
            st.audio(user_input)
        elif input_format == "Image":
            st.image(user_input, caption="User input")
        else:
            st.video(user_input)
    # Process the user input and generate a response via the dispatch table
    response = process_input(user_input, input_format, output_format)
    # Display the response; the placeholder handlers return plain text even
    # when a media output was requested, so branch on the actual type
    with io_container:
        if isinstance(response, str):
            st.chat_message("assistant").write(response)
        elif output_format == "Audio":
            st.audio(response)
        elif output_format == "Image":
            st.image(response, caption="Assistant output")
        else:
            st.video(response)
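# Usage note: launch with `streamlit run app.py`; Streamlit reruns the whole
# script on each widget interaction, which is why process_input is invoked
# inline rather than behind a callback.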