|
|
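"""Streamlit demo that routes user input between text, audio, image, and video
models, with the input and output formats chosen from the sidebar."""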
|
import streamlit as st
import torch
import transformers
import librosa
import numpy as np
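# Extra imports used by the reconstructed helpers below: PIL for opening
# uploaded images, tempfile for handing uploads to the video pipeline, and
# diffusers for the assumed text-to-image stand-in.
import tempfile

from PIL import Image
from diffusers import StableDiffusionPipeline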
|
|
|
|
|
# Model setup. Everything is loaded eagerly to keep the demo simple; these
# downloads are large, so a real app would wrap them in @st.cache_resource.
text_model = transformers.pipeline("text-generation", model="gpt2")
audio_model = transformers.Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
audio_tokenizer = transformers.Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
image_model = transformers.pipeline("image-classification", model="google/vit-base-patch16-224")
# The original "facebook/mmf-vit-base-16" checkpoint is not published on the Hub;
# a VideoMAE model fine-tuned on Kinetics-400 is assumed here instead.
video_model = transformers.pipeline("video-classification", model="MCG-NJU/videomae-base-finetuned-kinetics")
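# The original draft turned generated text into audio/images via
# gr.outputs.*.from_str, which is not a real Gradio API. As a stand-in, a
# text-to-speech pipeline and a Stable Diffusion pipeline are assumed below;
# any comparable TTS or diffusion checkpoints would work in their place.
tts_model = transformers.pipeline("text-to-speech", model="suno/bark-small")
image_gen_model = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")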
|
|
|
|
|
def text_to_text(prompt):
    # Continue the prompt with the text-generation pipeline.
    output = text_model(prompt, max_length=50)
    return output[0]["generated_text"]
|
|
|
def text_to_audio(prompt):
    # Generate a continuation, then synthesize speech from it with the TTS
    # pipeline. Audio helpers return a (waveform, sample_rate) pair.
    text = text_model(prompt, max_length=50)[0]["generated_text"]
    speech = tts_model(text)
    return np.squeeze(speech["audio"]), speech["sampling_rate"]


def text_to_image(prompt):
    # Generate a continuation, then render it with the Stable Diffusion pipeline.
    # (The original gr.outputs.Image.from_str call does not exist in Gradio.)
    text = text_model(prompt, max_length=50)[0]["generated_text"]
    return image_gen_model(text).images[0]


def text_to_video(prompt):
    # No text-to-video model is loaded in this demo (and gr.outputs.Video.from_str
    # does not exist), so the generated text is returned and shown as plain text.
    return text_model(prompt, max_length=50)[0]["generated_text"]
|
|
|
def transcribe_audio(uploaded_audio):
    # Shared speech-to-text step. Wav2Vec2 expects 16 kHz mono input, so the
    # upload is resampled on load.
    waveform, _ = librosa.load(uploaded_audio, sr=16000)
    input_values = torch.from_numpy(waveform).unsqueeze(0)
    with torch.no_grad():
        logits = audio_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return audio_tokenizer.batch_decode(predicted_ids)[0]


def audio_to_text(uploaded_audio):
    return transcribe_audio(uploaded_audio)


def audio_to_audio(uploaded_audio):
    # Pass the uploaded audio straight through as a (waveform, sample_rate) pair.
    waveform, sample_rate = librosa.load(uploaded_audio, sr=None)
    return waveform, sample_rate


def audio_to_image(uploaded_audio):
    # Transcribe the audio, then use the transcript as a text-to-image prompt
    # (the original gr.outputs.Image.from_str call does not exist in Gradio).
    return image_gen_model(transcribe_audio(uploaded_audio)).images[0]


def audio_to_video(uploaded_audio):
    # No text-to-video model is loaded in this demo, so the transcript is
    # returned and shown as plain text.
    return transcribe_audio(uploaded_audio)
|
|
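# The image-to-* helpers referenced by process_input below are missing from
# this draft. The reconstructions here follow the pattern of the other helpers:
# classify the image, then reuse the predicted label as text.
def image_to_text(uploaded_image):
    # Return the top label from the image-classification pipeline.
    image = Image.open(uploaded_image)
    return image_model(image)[0]["label"]


def image_to_audio(uploaded_image):
    # Classify the image, then speak the predicted label with the TTS pipeline.
    speech = tts_model(image_to_text(uploaded_image))
    return np.squeeze(speech["audio"]), speech["sampling_rate"]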
|
|
def image_to_image(uploaded_image):
    # Pass the uploaded image straight through.
    return uploaded_image
|
|
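# The remaining image/video helpers referenced by process_input are also
# missing from this draft. video_to_text assumes the video-classification
# pipeline can read a file path (it needs a video backend such as decord);
# conversions with no obvious model in this demo return an explanatory string,
# which the UI below renders as plain text.
def image_to_video(uploaded_image):
    return "Image-to-video conversion is not implemented in this demo."


def video_to_text(uploaded_video):
    # The video-classification pipeline expects a path, so write the upload to
    # a temporary file first.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        tmp.write(uploaded_video.getvalue())
    return video_model(tmp.name)[0]["label"]


def video_to_audio(uploaded_video):
    # Classify the video, then speak the predicted label with the TTS pipeline.
    speech = tts_model(video_to_text(uploaded_video))
    return np.squeeze(speech["audio"]), speech["sampling_rate"]


def video_to_image(uploaded_video):
    return "Video-to-image conversion is not implemented in this demo."


def video_to_video(uploaded_video):
    # Pass the uploaded video straight through.
    return uploaded_video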
|
|
# Dispatch table mapping (input_format, output_format) to a conversion helper.
CONVERTERS = {
    ("Text", "Text"): text_to_text,
    ("Text", "Audio"): text_to_audio,
    ("Text", "Image"): text_to_image,
    ("Text", "Video"): text_to_video,
    ("Audio", "Text"): audio_to_text,
    ("Audio", "Audio"): audio_to_audio,
    ("Audio", "Image"): audio_to_image,
    ("Audio", "Video"): audio_to_video,
    ("Image", "Text"): image_to_text,
    ("Image", "Audio"): image_to_audio,
    ("Image", "Image"): image_to_image,
    ("Image", "Video"): image_to_video,
    ("Video", "Text"): video_to_text,
    ("Video", "Audio"): video_to_audio,
    ("Video", "Image"): video_to_image,
    ("Video", "Video"): video_to_video,
}


def process_input(user_input, input_format, output_format):
    # Route the input to the helper registered for the selected format pair.
    converter = CONVERTERS.get((input_format, output_format))
    if converter is None:
        return "Invalid input or output format"
    return converter(user_input)
|
|
|
|
|
|
|
st.title("My Generic AI App") |
|
|
|
|
|
st.sidebar.header("Select the input and output formats") |
|
input_format = st.sidebar.selectbox("Input format", ["Text", "Audio", "Image", "Video"]) |
|
output_format = st.sidebar.selectbox("Output format", ["Text", "Audio", "Image", "Video"]) |
|
|
|
|
|
io_container = st.container() |
|
|
|
|
|
if input_format == "Text":
    user_input = st.chat_input("Type a text")
elif input_format == "Audio":
    user_input = st.file_uploader("Upload an audio file", type=["wav", "mp3", "ogg"])
elif input_format == "Image":
    user_input = st.file_uploader("Upload an image file", type=["jpg", "jpeg", "png", "gif"])
else:
    user_input = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
|
|
|
|
|
if user_input:
    with io_container:
        # Echo the user's input back in a widget matching the input format.
        if input_format == "Text":
            with st.chat_message("user"):
                st.write(user_input)
        elif input_format == "Audio":
            st.audio(user_input)
        elif input_format == "Image":
            st.image(user_input, caption="User input")
        else:
            st.video(user_input)

    response = process_input(user_input, input_format, output_format)

    with io_container:
        # Placeholder conversions return strings, so fall back to text for those.
        if output_format == "Text" or isinstance(response, str):
            with st.chat_message("assistant"):
                st.write(response)
        elif output_format == "Audio":
            waveform, sample_rate = response
            st.audio(waveform, sample_rate=sample_rate)
        elif output_format == "Image":
            st.image(response, caption="Assistant output")
        else:
            st.video(response)
|
|
|
|