HakimHa committed
Commit 50d8db6 · 1 Parent(s): 9378660

Update app.py

Files changed (1): app.py (+41 -4)
app.py CHANGED
@@ -1,10 +1,23 @@
 import gradio as gr
 from PIL import Image
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC, ViTFeatureExtractor, ViTForImageClassification
+
 import soundfile as sf
 import torch
 import numpy as np
 
+class_names = {
+    0: "Dog",
+    1: "Cat",
+    2: "Horse",
+    3: "Bird",
+    4: "Elephant",
+    5: "Lion",
+    6: "Fish",
+    7: "Bear",
+    8: "Snake"
+}
+
 model_name_or_path = "microsoft/DialoGPT-large"
 
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left", use_fast=False)
@@ -21,6 +34,10 @@ model = AutoModelForCausalLM.from_pretrained(
 wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
 wav2vec2_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
 
+vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
+vit_feature_extractor = ViTFeatureExtractor.from_pretrained('ohidaoui/monuments-morocco-v1')
+
+
 # Function to handle text input
 def handle_text(text):
     new_user_input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors='pt')
@@ -30,9 +47,28 @@ def handle_text(text):
     return chat_output
 
 # Function to handle image input
+def get_class_name(class_idx):
+    return class_names[class_idx]
+
+
 def handle_image(img):
-    return "This image seems nice!"
+    # Convert PIL image to numpy array
+    img = np.array(img)
 
+    # Apply transformations and prepare image for the model
+    inputs = vit_feature_extractor(images=img, return_tensors="pt")
+
+    # Pass through the Vision Transformer model
+    outputs = vit_model(**inputs)
+
+    # Get the predicted class
+    predicted_class_idx = torch.argmax(outputs.logits, dim=1).item()
+
+
+    predicted_class_name = get_class_name(predicted_class_idx)
+
+    return predicted_class_name
+
 # Function to handle audio input
 def handle_audio(audio):
     # gradio Audio returns a tuple (sample_rate, audio_np_array)
@@ -48,10 +84,11 @@ def handle_audio(audio):
 
 
 
+
 def chatbot(text, img, audio):
     text_output = handle_text(text) if text is not None else ''
-    img_output = handle_image(img) if img is not None else ''
-    audio_output = handle_audio(audio) if audio is not None else ''
+    img_output = handle_text(handle_image(img)) if img is not None else ''
+    audio_output = handle_text(handle_audio(audio)) if audio is not None else ''
 
     outputs = [o for o in [text_output, img_output, audio_output] if o]
     return "\n".join(outputs)