HakimHa committed
Commit ebba648 · 1 Parent(s): 0c14a46

Update app.py

Files changed (1)
app.py +44 -57
app.py CHANGED
@@ -1,11 +1,13 @@
 import gradio as gr
 from PIL import Image
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, WhisperProcessor, WhisperForConditionalGeneration, ViTFeatureExtractor, ViTForImageClassification
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC, ViTFeatureExtractor, ViTForImageClassification
 
 import soundfile as sf
 import torch
 import numpy as np
+import time
 
+# Initialize the transformers and the models
 class_names = {
     0: "al qarawiyyin",
     1: "bab mansour el aleuj",
@@ -16,94 +18,79 @@ class_names = {
     6: "madrasa ben youssef",
     7: "majorel gardens",
     8: "menara"
-}
+}
 
 model_name_or_path = "microsoft/DialoGPT-large"
-
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left", use_fast=False)
 tokenizer.pad_token = tokenizer.eos_token
+model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
 
-model = AutoModelForCausalLM.from_pretrained(
-    model_name_or_path,
-    torch_dtype=torch.float32,
-    device_map="auto",
-    trust_remote_code=True,
-)
-
-# Initialize the Wav2Vec2 model and processor
-wav2vec2_processor = WhisperProcessor.from_pretrained("openai/whisper-large")
-wav2vec2_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
-wav2vec2_model.config.forced_decoder_ids = None
+wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+wav2vec2_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
 
 vit_model = ViTForImageClassification.from_pretrained('ohidaoui/monuments-morocco-v1')
 vit_feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
 
-
 # Function to handle text input
 def handle_text(text):
-    new_user_input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors='pt')
-    bot_input_ids = new_user_input_ids
-    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
-    chat_output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
-    return chat_output
+    chat_output = chat({"question": text})
+    return chat_output["answer"]
 
 # Function to handle image input
 def get_class_name(class_idx):
     return class_names[class_idx]
 
-
 def handle_image(img):
-    # Convert PIL image to numpy array
     img = np.array(img)
-
-    # Apply transformations and prepare image for the model
    inputs = vit_feature_extractor(images=img, return_tensors="pt")
-
-    # Pass through the Vision Transformer model
     outputs = vit_model(**inputs)
-
-    # Get the predicted class
     predicted_class_idx = torch.argmax(outputs.logits, dim=1).item()
-
-
     predicted_class_name = get_class_name(predicted_class_idx)
+    chat_output = chat({"question": "what is " + predicted_class_name})
+    return chat_output["answer"]
 
-    return predicted_class_name
-
 # Function to handle audio input
 def handle_audio(audio):
-    # gradio Audio returns a tuple (sample_rate, audio_np_array)
-    # we only need the audio data, hence accessing the second element
     audio = audio[1]
-    input_values = wav2vec2_processor(audio, sampling_rate=16000, return_tensors="pt").input_values
-    # Convert to the expected tensor type
+    input_values = wav2vec2_processor(audio, sampling_rate=16_000, return_tensors="pt").input_values
     input_values = input_values.to(torch.float32)
     logits = wav2vec2_model(input_values).logits
     predicted_ids = torch.argmax(logits, dim=-1)
     transcriptions = wav2vec2_processor.decode(predicted_ids[0])
-    return transcriptions
-
-
-
+    chat_output = chat({"question": transcriptions})
+    return chat_output["answer"]
 
-def chatbot(text, img, audio):
+# Main function to handle the inputs
+def chatbot(history, text=None, img=None, audio=None):
     text_output = handle_text(text) if text is not None else ''
     img_output = handle_image(img) if img is not None else ''
     audio_output = handle_audio(audio) if audio is not None else ''
-
     outputs = [o for o in [text_output, img_output, audio_output] if o]
-    return "\n".join(outputs)
-
-iface = gr.Interface(
-    fn=chatbot,
-    inputs=[
-        gr.inputs.Textbox(lines=2, placeholder="Input Text here..."),
-        gr.inputs.Image(label="Upload Image"),
-        gr.inputs.Audio(source="microphone", label="Audio Input"),
-    ],
-    outputs=gr.outputs.Textbox(label="Output"),
-    title="Multimodal Chatbot",
-    description="This chatbot can handle text, image, and audio inputs. Try it out!",
-)
-
-iface.launch()
+    output = "\n".join(outputs)
+
+    history[-1][1] = output
+    for character in output:
+        history[-1][1] += character
+        time.sleep(0.05)
+        yield history
+
+with gr.Blocks() as demo:
+    chat_interface = gr.Chatbot([], elem_id="chatbot", height=750)
+
+    with gr.Row():
+        with gr.Column(scale=0.85):
+            text_input = gr.Textbox(
+                show_label=False,
+                placeholder="Input Text here...",
+                container=False
+            )
+        with gr.Column(scale=0.15, min_width=0):
+            img_input = gr.Image()
+            audio_input = gr.Audio(source="microphone", label="Audio Input")
+
+    text_msg = text_input.submit(chatbot, [chat_interface, text_input], [chat_interface, text_input], queue=False)
+    img_msg = img_input.upload(chatbot, [chat_interface, img_input], [chat_interface, img_input], queue=False)
    audio_msg = audio_input.upload(chatbot, [chat_interface, audio_input], [chat_interface, audio_input], queue=False)
+
+demo.queue()
+demo.launch(share=True)
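
Note: the updated handlers call a chat() helper that is not defined anywhere in app.py after this change. A minimal sketch of what such a helper could look like, reusing the DialoGPT model and tokenizer loaded at the top of the file together with the generation code that this commit removes from handle_text; the helper name and the {"question": ...} / {"answer": ...} dict shape are inferred from how the new handlers call it, not taken from the commit itself:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name_or_path = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32)

def chat(query):
    # Hypothetical helper: encode the question, generate a continuation with DialoGPT,
    # and return only the newly generated tokens as the answer.
    input_ids = tokenizer.encode(query["question"] + tokenizer.eos_token, return_tensors="pt")
    output_ids = model.generate(input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
    answer = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
    return {"answer": answer}

With a helper like this in scope, handle_text("what is menara") would return chat({"question": "what is menara"})["answer"], which is the shape the new handle_text, handle_image, and handle_audio functions expect.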