|
|
|
import streamlit as st |
|
import gradio as gr |
|
import torch |
|
import transformers |
|
import librosa |
|
import cv2 |
|
import numpy as np |
|
|
|
|
|
# Models are created once at import time and shared by every handler below.
# NOTE(review): from_pretrained/pipeline may download weights from the
# Hugging Face Hub on first use.

# Text-generation pipeline with the library's default checkpoint for the task.
text_model = transformers.pipeline("text-generation")

# wav2vec 2.0 CTC speech-recognition model plus its matching tokenizer; both
# must come from the same checkpoint so the vocabularies line up.
# NOTE(review): Wav2Vec2Tokenizer is deprecated upstream in favour of
# Wav2Vec2Processor; kept because the audio handlers call it directly.
audio_model = transformers.Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
audio_tokenizer = transformers.Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")

# Image-classification pipeline with the library's default checkpoint.
image_model = transformers.pipeline("image-classification")

# Pipeline *classes* expect already-loaded model objects, so constructing
# VideoClassificationPipeline with a checkpoint name fails. The pipeline()
# factory resolves the name into a model and loads the matching preprocessor
# from the same checkpoint.
# NOTE(review): confirm "facebook/mmf-vit-base-16" exists on the Hub as a
# video-classification checkpoint.
video_model = transformers.pipeline(
    "video-classification",
    model="facebook/mmf-vit-base-16",
)
|
|
|
|
|
def text_to_text(input):
    """Continue the prompt ``input`` using the shared text-generation pipeline.

    Args:
        input: Prompt text handed to ``text_model``.

    Returns:
        The ``"generated_text"`` field of the pipeline's first result,
        generated with ``max_length=50``.
    """
    results = text_model(input, max_length=50)
    first_result = results[0]
    return first_result["generated_text"]
|
|
|
def text_to_audio(input):
    """Convert text to audio (not implemented).

    No audio-generating (text-to-speech) model is loaded in this module, so
    there is nothing that can turn text into sound. The handler fails fast
    instead of running text generation whose result it cannot use.

    Args:
        input: Prompt text (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "text_to_audio is not implemented: no text-to-speech model is loaded"
    )
|
|
|
def text_to_image(input):
    """Convert text to an image (not implemented).

    No image-generating model is loaded in this module, so the handler fails
    fast instead of running text generation whose result it cannot use.

    Args:
        input: Prompt text (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "text_to_image is not implemented: no text-to-image model is loaded"
    )
|
|
|
def text_to_video(input):
    """Convert text to a video (not implemented).

    No video-generating model is loaded in this module, so the handler fails
    fast instead of running text generation whose result it cannot use.

    Args:
        input: Prompt text (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "text_to_video is not implemented: no text-to-video model is loaded"
    )
|
|
|
def audio_to_text(input):
    """Transcribe an audio clip with the shared wav2vec 2.0 CTC model.

    Args:
        input: Audio source accepted by ``librosa.load`` (typically a file
            path).

    Returns:
        The greedy-decoded transcription of the whole clip as one string.
    """
    # wav2vec2-base-960h expects 16 kHz input; without an explicit sr,
    # librosa.load resamples to its 22,050 Hz default and distorts features.
    speech, _ = librosa.load(input, sr=16000)
    # Let the tokenizer prepare the waveform: it applies its configured input
    # normalisation and adds the batch dimension.
    input_values = audio_tokenizer(speech, return_tensors="pt").input_values
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = audio_model(input_values).logits
    # Greedy CTC decoding: best token per frame, then collapse via the tokenizer.
    predicted_ids = torch.argmax(logits, dim=-1)
    return audio_tokenizer.batch_decode(predicted_ids)[0]
|
|
|
def audio_to_audio(input):
    """Pass the audio value through untouched.

    Args:
        input: Whatever audio value the caller supplies.

    Returns:
        The very same object that was passed in.
    """
    passthrough = input
    return passthrough
|
|
|
def audio_to_image(input):
    """Convert audio to an image (not implemented).

    No image-generating model is loaded in this module, so the handler fails
    fast instead of running a full speech transcription whose result it
    cannot use.

    Args:
        input: Audio source (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "audio_to_image is not implemented: no image-generation model is loaded"
    )
|
|
|
def audio_to_video(input):
    """Convert audio to a video (not implemented).

    No video-generating model is loaded in this module, so the handler fails
    fast instead of running a full speech transcription whose result it
    cannot use.

    Args:
        input: Audio source (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "audio_to_video is not implemented: no video-generation model is loaded"
    )
|
|
|
def image_to_text(input):
    """Return the top predicted class label for an image.

    Args:
        input: Path (or URL) of the image to classify.

    Returns:
        The ``"label"`` of the first (highest-scoring) prediction returned by
        ``image_model``.
    """
    # Hand the path straight to the pipeline, which loads and RGB-converts the
    # image itself. Pre-loading with cv2 produced a batched numpy array the
    # pipeline's image loader does not accept. cv2.imread also returns None
    # (no exception) for unreadable files.
    predictions = image_model(input)
    return predictions[0]["label"]
|
|
|
def image_to_audio(input):
    """Convert an image to audio (not implemented).

    No audio-generating model is loaded in this module, so the handler fails
    fast instead of running image classification whose result it cannot use.

    Args:
        input: Image source (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "image_to_audio is not implemented: no audio-generation model is loaded"
    )
|
|
|
def image_to_image(input):
    """Pass the image value through untouched.

    Args:
        input: Whatever image value the caller supplies.

    Returns:
        The very same object that was passed in.
    """
    passthrough = input
    return passthrough
|
|
|
def image_to_video(input):
    """Convert an image to a video (not implemented).

    No video-generating model is loaded in this module, so the handler fails
    fast instead of running image classification whose result it cannot use.

    Args:
        input: Image source (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "image_to_video is not implemented: no video-generation model is loaded"
    )
|
|
|
def video_to_text(input): |
|
input = cv2.VideoCapture(input) |
|
frames = [] |
|
|