badrex commited on
Commit
b78ce3a
·
verified ·
1 Parent(s): 2134b08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -41
app.py CHANGED
@@ -1,52 +1,49 @@
1
- import gradio as gr
2
- from transformers import pipeline
3
- import os
4
- import numpy as np
5
  import spaces
6
 
7
- print("=== Application Starting ===")
 
 
 
 
 
 
 
 
 
 
8
 
9
  # define dialect mapping
 
10
  dialect_mapping = {
11
  "MSA": "Modern Standard Arabic (MSA) - العربية الفصحى الحديثة",
12
- "Egyptian": "Egyptian Arabic - اللهجة المصرية العامية",
13
- "Gulf": "Peninsular Arabic - لهجة الجزيرة العربية",
14
- "Levantine": "Levantine Arabic - لهجة بلاد الشام",
15
- "Maghrebi": "Maghrebi Arabic - اللهجة المغاربية"
16
- }
17
-
18
- @spaces.GPU
19
- def predict_dialect(audio):
20
- # load model inside the GPU function
21
- print("Loading model on GPU...")
22
- model_id = "badrex/mms-300m-arabic-dialect-identifier"
23
- classifier = pipeline("audio-classification", model=model_id) # no device specified
24
- print("Model loaded successfully")
25
-
26
  if audio is None:
27
  return {"Error": 1.0}
28
 
 
29
  sr, audio_array = audio
30
 
 
31
  if len(audio_array.shape) > 1:
32
  audio_array = audio_array.mean(axis=1)
33
 
 
34
  if audio_array.dtype != np.float32:
 
35
  if audio_array.dtype == np.int16:
36
  audio_array = audio_array.astype(np.float32) / 32768.0
37
  else:
38
- audio_array = audio_array.astype(np.float32)
39
 
40
  print(f"Processing audio: sample rate={sr}, shape={audio_array.shape}")
41
 
42
- # classify the dialect
43
- predictions = classifier({"sampling_rate": sr, "raw": audio_array})
44
 
45
- # format results
46
  results = {}
47
  for pred in predictions:
48
  dialect_name = dialect_mapping.get(pred['label'], pred['label'])
49
- results[dialect_name] = float(pred['score'])
50
 
51
  return results
52
 
@@ -57,25 +54,16 @@ if os.path.exists(examples_dir):
57
  for filename in os.listdir(examples_dir):
58
  if filename.endswith((".wav", ".mp3", ".ogg")):
59
  examples.append([os.path.join(examples_dir, filename)])
 
60
  print(f"Found {len(examples)} example files")
 
 
61
 
 
62
  description = """
63
  By <a href="https://badrex.github.io/">Badr Alabsi</a> with ❤️🤍💚
64
 
65
- This demo uses a Transformer-based model for Spoken Arabic Dialect Identification.
66
- Upload an audio file or record yourself speaking to identify the Arabic dialect!
67
- """
68
-
69
- demo = gr.Interface(
70
- fn=predict_dialect,
71
- inputs=gr.Audio(),
72
- outputs=gr.Label(num_top_classes=5, label="Predicted Dialect"),
73
- title="Tamyïz 🍉 Arabic Dialect Identification in Speech",
74
- description=description,
75
- examples=examples if examples else None,
76
- cache_examples=False,
77
- flagging_mode=None
78
- )
79
-
80
- print("=== Launching demo ===")
81
- demo.launch()
 
1
+ import torch
 
 
 
2
  import spaces
3
 
4
+ # load the model
5
+ print("Loading model...")
6
+ model_id = "badrex/mms-300m-arabic-dialect-identifier"
7
+ classifier = pipeline("audio-classification", model=model_id, device='cuda')
8
+ print("Model loaded successfully")
9
+ print("Model moved to GPU successfully")
10
+
11
+
12
+ @spaces.GPU
13
+ def predict(audio_segment, sr=16000):
14
+ return classifier({"sampling_rate": sr, "raw": audio_segment})
15
 
16
  # define dialect mapping
17
+
18
  dialect_mapping = {
19
  "MSA": "Modern Standard Arabic (MSA) - العربية الفصحى الحديثة",
20
+ "Egyptian": "Egyptian Arabic - اللهجة المصرية العامية",
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  if audio is None:
22
  return {"Error": 1.0}
23
 
24
+
25
  sr, audio_array = audio
26
 
27
+
28
  if len(audio_array.shape) > 1:
29
  audio_array = audio_array.mean(axis=1)
30
 
31
+
32
  if audio_array.dtype != np.float32:
33
+
34
  if audio_array.dtype == np.int16:
35
  audio_array = audio_array.astype(np.float32) / 32768.0
36
  else:
 
37
 
38
  print(f"Processing audio: sample rate={sr}, shape={audio_array.shape}")
39
 
40
+
41
+ predictions = predict(sr=sr, audio_segment=audio_array)
42
 
43
+
44
  results = {}
45
  for pred in predictions:
46
  dialect_name = dialect_mapping.get(pred['label'], pred['label'])
 
47
 
48
  return results
49
 
 
54
  for filename in os.listdir(examples_dir):
55
  if filename.endswith((".wav", ".mp3", ".ogg")):
56
  examples.append([os.path.join(examples_dir, filename)])
57
+
58
  print(f"Found {len(examples)} example files")
59
+ else:
60
+ print("Examples directory not found")
61
 
62
+ # clean description without problematic HTML
63
  description = """
64
  By <a href="https://badrex.github.io/">Badr Alabsi</a> with ❤️🤍💚
65
 
66
+ This is a demo for the accurate and robust Transformer-based <a href="https://huggingface.co/badrex/mms-300m-arabic-dialect-identifier">model</a> for Spoken Arabic Dialect Identification (ADI).
67
+ From just a short audio clip (5-10 seconds), the model can identify Modern Standard Arabic (MSA) as well as four major regional Arabic varieties: Egyptian Arabic, Peninsular Arabic (Gulf, Yemeni, and Iraqi), Levantine Arabic, and Maghrebi Arabic.
68
+
69
+ Simply **upload an audio file** 📤 or **record yourself speaking** 🎙️⏺️ to try out the model!