import gradio as gr
from PIL import Image
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC, ViTFeatureExtractor, ViTForImageClassification

import soundfile as sf
import torch
import numpy as np
import time

# Label map for the monument classifier (class index -> monument name)
class_names = {
    0: "al qarawiyyin",
    1: "bab mansour el aleuj",
    2: "chaouara tannery",
    3: "hassan tower",
    4: "jamae el fna",
    5: "koutoubia mosque",
    6: "madrasa ben youssef",
    7: "majorel gardens",
    8: "menara"
}

# DialoGPT-large generates the chatbot's answers; decoder-only models pad on the left.
model_name_or_path = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)

# Wav2Vec2 (base model trained on 960 h of LibriSpeech) transcribes spoken questions to text.
wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
wav2vec2_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# ViT classifier fine-tuned on Moroccan monuments; preprocessing comes from the generic
# google/vit-base-patch16-224 checkpoint.
vit_model = ViTForImageClassification.from_pretrained('ohidaoui/monuments-morocco-v1')
vit_feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
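
# NOTE: the handlers below call chat({"question": ...})["answer"], but no `chat` object is
# defined in this file (it may have lived in another module, e.g. a retrieval chain over
# monument descriptions). The function below is a minimal stand-in, not the original
# implementation: it wraps the DialoGPT model loaded above behind the same
# dict-in / dict-out interface so the rest of the script can run.
def chat(inputs):
    question = inputs["question"]
    input_ids = tokenizer.encode(question + tokenizer.eos_token, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=128,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (the reply), not the prompt.
    answer = tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
    return {"answer": answer}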

# Function to handle text input
def handle_text(text):
    chat_output = chat({"question": text})
    return chat_output["answer"]

# Function to handle image input
def get_class_name(class_idx):
    return class_names[class_idx]

def handle_image(img):
    # Classify the uploaded photo with the fine-tuned ViT model.
    img = np.array(img)
    inputs = vit_feature_extractor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = vit_model(**inputs)
    predicted_class_idx = torch.argmax(outputs.logits, dim=1).item()
    predicted_class_name = get_class_name(predicted_class_idx)
    # Ask the chat model about the recognized monument.
    chat_output = chat({"question": "what is " + predicted_class_name})
    return chat_output["answer"]

# Function to handle audio input
def handle_audio(audio):
    # Gradio delivers microphone input as a (sample_rate, numpy_array) tuple of 16-bit PCM.
    sample_rate, data = audio
    # Wav2Vec2 expects a float waveform; wav2vec2-base-960h is trained on 16 kHz audio,
    # so recordings made at other rates should be resampled before transcription.
    data = data.astype(np.float32) / 32768.0
    input_values = wav2vec2_processor(data, sampling_rate=16_000, return_tensors="pt").input_values
    with torch.no_grad():
        logits = wav2vec2_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = wav2vec2_processor.decode(predicted_ids[0])
    chat_output = chat({"question": transcription})
    return chat_output["answer"]

# Main handler: route each input type to its model and stream the reply into the chat.
def chatbot(history, text=None, img=None, audio=None):
    text_output = handle_text(text) if text else ''
    img_output = handle_image(img) if img is not None else ''
    audio_output = handle_audio(audio) if audio is not None else ''
    outputs = [o for o in [text_output, img_output, audio_output] if o]
    output = "\n".join(outputs)

    # Add a new turn to the history, then stream the reply character by character.
    history = history + [[text if text else "(media input)", ""]]
    yield history
    for character in output:
        history[-1][1] += character
        time.sleep(0.05)
        yield history

# Thin wrappers so each Gradio event passes its component value to the right argument.
def chatbot_text(history, text):
    yield from chatbot(history, text=text)

def chatbot_image(history, img):
    yield from chatbot(history, img=img)

def chatbot_audio(history, audio):
    yield from chatbot(history, audio=audio)

with gr.Blocks() as demo:
    chat_interface = gr.Chatbot([], elem_id="chatbot", height=750)

    with gr.Row():
        with gr.Column(scale=0.85):
            text_input = gr.Textbox(
                show_label=False,
                placeholder="Input Text here...",
                container=False
            )
        with gr.Column(scale=0.15, min_width=0):
            img_input = gr.Image()
            audio_input = gr.Audio(source="microphone", label="Audio Input")

    # Streaming (generator) handlers need the queue, so the events are not marked queue=False.
    text_msg = text_input.submit(chatbot_text, [chat_interface, text_input], [chat_interface])
    img_msg = img_input.upload(chatbot_image, [chat_interface, img_input], [chat_interface])
    # With a microphone source the upload event never fires; stop_recording (available in
    # recent Gradio releases) triggers once the user finishes recording.
    audio_msg = audio_input.stop_recording(chatbot_audio, [chat_interface, audio_input], [chat_interface])

demo.queue()
demo.launch(share=True)