dioarafl commited on
Commit
bf3e8dc
·
verified ·
1 Parent(s): c5a3421

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -2,11 +2,13 @@ import cv2
2
  import gradio as gr
3
  import tempfile
4
  import torch
 
5
  from torchvision.models.detection import fasterrcnn_resnet50_fpn
6
  import torchvision.transforms as transforms
7
  from PIL import Image
8
  import numpy as np
9
  import soundfile as sf
 
10
 
11
  class FasterRCNNDetector:
12
  def __init__(self):
@@ -49,8 +51,8 @@ class FasterRCNNDetector:
49
  class JarvisModels:
50
  def __init__(self):
51
  self.client1 = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
52
- self.model = deepspeech.Model("deepspeech-0.9.3-models.pbmm")
53
- self.model.setBeamWidth(500)
54
 
55
  async def generate_response(self, prompt):
56
  generate_kwargs = dict(
@@ -74,20 +76,18 @@ class JarvisModels:
74
  communicate.save(tmp_path)
75
  return tmp_path
76
 
77
- def transcribe_audio(audio_file):
78
- model = JarvisModels().model
79
- audio, sample_rate = sf.read(audio_file)
80
- return model.stt(audio)
81
-
82
- def generate_response(frame):
83
- jarvis = JarvisModels()
84
- response_model = await jarvis.generate_response("Hello, I see some interesting objects!")
85
- return response_model
86
 
87
  detector = FasterRCNNDetector()
88
 
89
  iface = gr.Interface(
90
- fn=[detector.detect_objects, transcribe_audio],
91
  inputs=gr.inputs.Video(label="Webcam", parameters={"fps": 30}),
92
  outputs=[gr.outputs.Image(), "text"],
93
  title="Vision and Speech Interface",
 
2
  import gradio as gr
3
  import tempfile
4
  import torch
5
+ import torchaudio
6
  from torchvision.models.detection import fasterrcnn_resnet50_fpn
7
  import torchvision.transforms as transforms
8
  from PIL import Image
9
  import numpy as np
10
  import soundfile as sf
11
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
12
 
13
  class FasterRCNNDetector:
14
  def __init__(self):
 
51
  class JarvisModels:
52
  def __init__(self):
53
  self.client1 = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
54
+ self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
55
+ self.model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
56
 
57
  async def generate_response(self, prompt):
58
  generate_kwargs = dict(
 
76
  communicate.save(tmp_path)
77
  return tmp_path
78
 
79
+ async def transcribe_audio(self, audio_file):
80
+ input_audio, _ = torchaudio.load(audio_file)
81
+ input_values = self.processor(input_audio, return_tensors="pt").input_values
82
+ logits = self.model(input_values).logits
83
+ predicted_ids = torch.argmax(logits, dim=-1)
84
+ transcription = self.processor.batch_decode(predicted_ids)
85
+ return transcription[0]
 
 
86
 
87
  detector = FasterRCNNDetector()
88
 
89
  iface = gr.Interface(
90
+ fn=[detector.detect_objects, JarvisModels().transcribe_audio],
91
  inputs=gr.inputs.Video(label="Webcam", parameters={"fps": 30}),
92
  outputs=[gr.outputs.Image(), "text"],
93
  title="Vision and Speech Interface",