# Zbot / app.py
# Import libraries
import gradio as gr
import torch
import transformers
import librosa
import cv2
import numpy as np
# Load models
text_model = transformers.pipeline("text-generation")
audio_model = transformers.Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
audio_tokenizer = transformers.Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
image_model = transformers.pipeline("image-classification")
# VideoClassificationPipeline is not meant to be constructed from checkpoint
# names directly; the pipeline() factory builds it. The checkpoint name below
# is kept from the original file.
video_model = transformers.pipeline("video-classification", model="facebook/mmf-vit-base-16")
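# Optional note (an addition, not in the original): pipelines accept a
# `device` index, e.g. transformers.pipeline("text-generation", device=0)
# for the first GPU, and the raw wav2vec2 model can be moved with
# audio_model.to("cuda") as long as its input tensors are moved to match.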
# Define functions for processing inputs and outputs
def text_to_text(input):
output = text_model(input, max_length=50)
return output[0]["generated_text"]
def text_to_audio(prompt):
    # gr.outputs.Audio.from_str does not exist in Gradio, so the original call
    # could never run. With no TTS model loaded, the most this app can do is
    # return the generated text; see the sketch below for a real speech path.
    output = text_model(prompt, max_length=50)
    return output[0]["generated_text"]
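# Hedged sketch of a genuine text-to-speech path (an addition): assumes
# transformers >= 4.32, which introduced the "text-to-speech" pipeline task,
# and uses "suno/bark-small" purely as an illustrative checkpoint.
def text_to_speech_sketch(prompt):
    tts = transformers.pipeline("text-to-speech", model="suno/bark-small")
    speech = tts(text_to_text(prompt))
    # Gradio's Audio component accepts a (sampling_rate, waveform) tuple.
    return speech["sampling_rate"], np.squeeze(speech["audio"])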
def text_to_image(prompt):
    # gr.outputs.Image.from_str likewise does not exist; generated text cannot
    # become an image without a text-to-image model (see the sketch below).
    output = text_model(prompt, max_length=50)
    return output[0]["generated_text"]
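# Hedged sketch of a genuine text-to-image path (an addition): assumes the
# separate `diffusers` library, which the original file does not import, with
# "runwayml/stable-diffusion-v1-5" as an illustrative public checkpoint.
def text_to_image_sketch(prompt):
    from diffusers import StableDiffusionPipeline
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    # pipe(...).images is a list of PIL images; Gradio's Image component accepts PIL.
    return pipe(text_to_text(prompt)).images[0]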
def text_to_video(prompt):
    # gr.outputs.Video.from_str does not exist either, and none of the loaded
    # models can synthesize video; return the generated text instead.
    output = text_model(prompt, max_length=50)
    return output[0]["generated_text"]
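# Helper (an addition): the three audio functions below all repeated the same
# transcription steps, factored out here. wav2vec2-base-960h expects 16 kHz
# mono audio, so the waveform is resampled on load; the original relied on
# librosa's 22.05 kHz default, which degrades the transcription.
def transcribe(audio_path):
    speech, _ = librosa.load(audio_path, sr=16000)
    input_values = torch.from_numpy(speech).unsqueeze(0)
    with torch.no_grad():  # inference only; no gradients needed
        logits = audio_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return audio_tokenizer.batch_decode(predicted_ids)[0]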
def audio_to_text(audio_path):
    # Transcribe the uploaded audio with wav2vec2 via the shared helper above.
    return transcribe(audio_path)
def audio_to_audio(audio_path):
    # Identity mapping: return the audio unchanged.
    return audio_path
def audio_to_image(audio_path):
    # gr.outputs.Image.from_str does not exist; without a text-to-image model
    # the transcript itself is the only output these models can produce.
    return transcribe(audio_path)
def audio_to_video(audio_path):
    # Same limitation as above: no loaded model can turn a transcript into
    # video, so the transcript is returned as-is.
    return transcribe(audio_path)
def image_to_text(image_path):
    # The image-classification pipeline accepts a file path (or PIL image)
    # directly; the original cv2 + np.expand_dims preprocessing produced an
    # array shape the pipeline cannot consume.
    output = image_model(image_path)
    return output[0]["label"]
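# Note (an addition): the pipeline returns a score-sorted list of dicts, e.g.
# [{"label": "tabby, tabby cat", "score": 0.94}, ...], so output[0]["label"]
# is the top-1 class name.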
def image_to_audio(image_path):
    # gr.outputs.Audio.from_str does not exist; speaking the predicted label
    # would need a TTS model (see text_to_speech_sketch above).
    output = image_model(image_path)
    return output[0]["label"]
def image_to_image(image_path):
    # Identity mapping: return the image unchanged.
    return image_path
def image_to_video(image_path):
    # As above, no loaded model can animate a label into video; return the
    # predicted label instead of calling the nonexistent gr.outputs.Video.from_str.
    output = image_model(image_path)
    return output[0]["label"]
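# Note (an addition): recent transformers video-classification pipelines can
# take a video file path directly, which would make the manual cv2 frame loop
# below unnecessary; the loop is kept to match the original structure.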
def video_to_text(video_path):
    # Decode the video with OpenCV frame by frame for classification.
    video = cv2.VideoCapture(video_path)
    frames = []