badrex commited on
Commit
d46aa73
·
verified ·
1 Parent(s): 9f08751

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import torch
2
  import spaces
3
 
@@ -8,42 +12,42 @@ classifier = pipeline("audio-classification", model=model_id, device='cuda')
8
  print("Model loaded successfully")
9
  print("Model moved to GPU successfully")
10
 
11
-
12
  @spaces.GPU
13
  def predict(audio_segment, sr=16000):
14
  return classifier({"sampling_rate": sr, "raw": audio_segment})
15
 
16
  # define dialect mapping
17
-
18
  dialect_mapping = {
19
  "MSA": "Modern Standard Arabic (MSA) - العربية الفصحى الحديثة",
20
  "Egyptian": "Egyptian Arabic - اللهجة المصرية العامية",
 
 
 
 
 
 
21
  if audio is None:
22
  return {"Error": 1.0}
23
 
24
-
25
  sr, audio_array = audio
26
 
27
-
28
  if len(audio_array.shape) > 1:
29
  audio_array = audio_array.mean(axis=1)
30
 
31
-
32
  if audio_array.dtype != np.float32:
33
-
34
  if audio_array.dtype == np.int16:
35
  audio_array = audio_array.astype(np.float32) / 32768.0
36
  else:
 
37
 
38
  print(f"Processing audio: sample rate={sr}, shape={audio_array.shape}")
39
 
40
-
41
  predictions = predict(sr=sr, audio_segment=audio_array)
42
 
43
-
44
  results = {}
45
  for pred in predictions:
46
  dialect_name = dialect_mapping.get(pred['label'], pred['label'])
 
47
 
48
  return results
49
 
@@ -54,14 +58,13 @@ if os.path.exists(examples_dir):
54
  for filename in os.listdir(examples_dir):
55
  if filename.endswith((".wav", ".mp3", ".ogg")):
56
  examples.append([os.path.join(examples_dir, filename)])
57
-
58
  print(f"Found {len(examples)} example files")
59
  else:
60
  print("Examples directory not found")
61
 
62
  # clean description without problematic HTML
63
  description = """
64
- By <a href="https://badrex.github.io/">Badr Alabsi</a> with ❤️🤍💚
65
 
66
  This is a demo for the accurate and robust Transformer-based <a href="https://huggingface.co/badrex/mms-300m-arabic-dialect-identifier">model</a> for Spoken Arabic Dialect Identification (ADI).
67
  From just a short audio clip (5-10 seconds), the model can identify Modern Standard Arabic (MSA) as well as four major regional Arabic varieties: Egyptian Arabic, Peninsular Arabic (Gulf, Yemeni, and Iraqi), Levantine Arabic, and Maghrebi Arabic.
@@ -69,7 +72,6 @@ From just a short audio clip (5-10 seconds), the model can identify Modern Stand
69
  Simply **upload an audio file** 📤 or **record yourself speaking** 🎙️⏺️ to try out the model!
70
  """
71
 
72
-
73
  demo = gr.Interface(
74
  fn=predict_dialect,
75
  inputs=gr.Audio(),
@@ -78,9 +80,7 @@ demo = gr.Interface(
78
  description=description,
79
  examples=examples if examples else None,
80
  cache_examples=False,
81
-
82
  flagging_mode=None
83
  )
84
 
85
-
86
  demo.launch(share=True)
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import os
4
+ import numpy as np
5
  import torch
6
  import spaces
7
 
 
12
  print("Model loaded successfully")
13
  print("Model moved to GPU successfully")
14
 
 
15
  @spaces.GPU
16
  def predict(audio_segment, sr=16000):
17
  return classifier({"sampling_rate": sr, "raw": audio_segment})
18
 
19
  # define dialect mapping
 
20
  dialect_mapping = {
21
  "MSA": "Modern Standard Arabic (MSA) - العربية الفصحى الحديثة",
22
  "Egyptian": "Egyptian Arabic - اللهجة المصرية العامية",
23
+ "Gulf": "Peninsular Arabic - لهجة الجزيرة العربية",
24
+ "Levantine": "Levantine Arabic - لهجة بلاد الشام",
25
+ "Maghrebi": "Maghrebi Arabic - اللهجة المغاربية"
26
+ }
27
+
28
+ def predict_dialect(audio):
29
  if audio is None:
30
  return {"Error": 1.0}
31
 
 
32
  sr, audio_array = audio
33
 
 
34
  if len(audio_array.shape) > 1:
35
  audio_array = audio_array.mean(axis=1)
36
 
 
37
  if audio_array.dtype != np.float32:
 
38
  if audio_array.dtype == np.int16:
39
  audio_array = audio_array.astype(np.float32) / 32768.0
40
  else:
41
+ audio_array = audio_array.astype(np.float32)
42
 
43
  print(f"Processing audio: sample rate={sr}, shape={audio_array.shape}")
44
 
 
45
  predictions = predict(sr=sr, audio_segment=audio_array)
46
 
 
47
  results = {}
48
  for pred in predictions:
49
  dialect_name = dialect_mapping.get(pred['label'], pred['label'])
50
+ results[dialect_name] = float(pred['score'])
51
 
52
  return results
53
 
 
58
  for filename in os.listdir(examples_dir):
59
  if filename.endswith((".wav", ".mp3", ".ogg")):
60
  examples.append([os.path.join(examples_dir, filename)])
 
61
  print(f"Found {len(examples)} example files")
62
  else:
63
  print("Examples directory not found")
64
 
65
  # clean description without problematic HTML
66
  description = """
67
+ Developed with ❤️🤍💚 by <a href="https://badrex.github.io/">Badr al-Absi</a>
68
 
69
  This is a demo for the accurate and robust Transformer-based <a href="https://huggingface.co/badrex/mms-300m-arabic-dialect-identifier">model</a> for Spoken Arabic Dialect Identification (ADI).
70
  From just a short audio clip (5-10 seconds), the model can identify Modern Standard Arabic (MSA) as well as four major regional Arabic varieties: Egyptian Arabic, Peninsular Arabic (Gulf, Yemeni, and Iraqi), Levantine Arabic, and Maghrebi Arabic.
 
72
  Simply **upload an audio file** 📤 or **record yourself speaking** 🎙️⏺️ to try out the model!
73
  """
74
 
 
75
  demo = gr.Interface(
76
  fn=predict_dialect,
77
  inputs=gr.Audio(),
 
80
  description=description,
81
  examples=examples if examples else None,
82
  cache_examples=False,
 
83
  flagging_mode=None
84
  )
85
 
 
86
  demo.launch(share=True)