import gradio as gr
from PIL import Image
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC, ViTFeatureExtractor, ViTForImageClassification
import soundfile as sf
import torch
import numpy as np
import time
# Initialize the tokenizers, processors, and pretrained models
class_names = {
    0: "al qarawiyyin",
    1: "bab mansour el aleuj",
    2: "chaouara tannery",
    3: "hassan tower",
    4: "jamae el fna",
    5: "koutoubia mosque",
    6: "madrasa ben youssef",
    7: "majorel gardens",
    8: "menara"
}
# DialoGPT generates the chatbot answers
model_name_or_path = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
# wav2vec2 transcribes spoken questions to text
wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
wav2vec2_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
# ViT fine-tuned on Moroccan monuments classifies uploaded images
vit_model = ViTForImageClassification.from_pretrained('ohidaoui/monuments-morocco-v1')
vit_feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
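# NOTE: `chat` is called by the handlers below but is not defined in this file.
# The sketch here is an assumption, not the original implementation: it answers
# a {"question": ...} payload with the DialoGPT model loaded above and returns
# a {"answer": ...} dict matching the interface the handlers expect.
def chat(payload):
    question = payload["question"]
    # DialoGPT expects the prompt to be terminated with the EOS token
    input_ids = tokenizer.encode(question + tokenizer.eos_token, return_tensors="pt")
    output_ids = model.generate(
        input_ids,
        max_length=200,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the prompt tokens and decode only the generated reply
    answer = tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
    return {"answer": answer}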
# Function to handle text input
def handle_text(text):
    chat_output = chat({"question": text})
    return chat_output["answer"]
# Function to handle image input
def get_class_name(class_idx):
    return class_names[class_idx]
def handle_image(img):
    img = np.array(img)
    inputs = vit_feature_extractor(images=img, return_tensors="pt")
    outputs = vit_model(**inputs)
    # Pick the most likely monument class and ask the chatbot about it
    predicted_class_idx = torch.argmax(outputs.logits, dim=1).item()
    predicted_class_name = get_class_name(predicted_class_idx)
    chat_output = chat({"question": "what is " + predicted_class_name})
    return chat_output["answer"]
# Function to handle audio input
def handle_audio(audio):
    # Gradio returns (sample_rate, data); keep only the waveform.
    # wav2vec2-base-960h expects 16 kHz audio, so the recording is assumed
    # to already be at that sampling rate.
    audio = audio[1]
    input_values = wav2vec2_processor(audio, sampling_rate=16_000, return_tensors="pt").input_values
    input_values = input_values.to(torch.float32)
    logits = wav2vec2_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcriptions = wav2vec2_processor.decode(predicted_ids[0])
    chat_output = chat({"question": transcriptions})
    return chat_output["answer"]
# Main function to handle the inputs
def chatbot(history, text=None, img=None, audio=None):
    text_output = handle_text(text) if text else ''
    img_output = handle_image(img) if img is not None else ''
    audio_output = handle_audio(audio) if audio is not None else ''
    outputs = [o for o in [text_output, img_output, audio_output] if o]
    output = "\n".join(outputs)
    # Stream the answer character by character into the last chat turn
    history[-1][1] = ""
    for character in output:
        history[-1][1] += character
        time.sleep(0.05)
        yield history
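# NOTE (assumption): the original events passed every component value to the
# `text` parameter of `chatbot` and never appended the user turn that
# `chatbot` streams into. These thin wrappers route each input to the right
# keyword argument and add a placeholder user message to the history first.
def chatbot_text(history, text):
    history = history + [[text, ""]]
    yield from chatbot(history, text=text)

def chatbot_image(history, img):
    history = history + [["(image uploaded)", ""]]
    yield from chatbot(history, img=img)

def chatbot_audio(history, audio):
    history = history + [["(audio message)", ""]]
    yield from chatbot(history, audio=audio)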
with gr.Blocks() as demo:
    chat_interface = gr.Chatbot([], elem_id="chatbot", height=750)
    with gr.Row():
        with gr.Column(scale=0.85):
            text_input = gr.Textbox(
                show_label=False,
                placeholder="Input Text here...",
                container=False
            )
        with gr.Column(scale=0.15, min_width=0):
            img_input = gr.Image()
            audio_input = gr.Audio(source="microphone", label="Audio Input")
    # Generator handlers need the queue, so queue=False is dropped; the
    # chatbot yields only the history, so it is the single output.
    text_msg = text_input.submit(chatbot_text, [chat_interface, text_input], [chat_interface])
    img_msg = img_input.upload(chatbot_image, [chat_interface, img_input], [chat_interface])
    # Microphone recordings do not fire `upload`; react when recording stops
    audio_msg = audio_input.stop_recording(chatbot_audio, [chat_interface, audio_input], [chat_interface])
demo.queue()
demo.launch(share=True)