Update app.py
app.py CHANGED
@@ -1,3 +1,7 @@
+import gradio as gr
+from transformers import pipeline
+import os
+import numpy as np
 import torch
 import spaces
 
@@ -8,42 +12,42 @@ classifier = pipeline("audio-classification", model=model_id, device='cuda')
 print("Model loaded successfully")
 print("Model moved to GPU successfully")
 
-
 @spaces.GPU
 def predict(audio_segment, sr=16000):
     return classifier({"sampling_rate": sr, "raw": audio_segment})
 
 # define dialect mapping
-
 dialect_mapping = {
     "MSA": "Modern Standard Arabic (MSA) - العربية الفصحى الحديثة",
     "Egyptian": "Egyptian Arabic - اللهجة المصرية العامية",
+    "Gulf": "Peninsular Arabic - لهجة الجزيرة العربية",
+    "Levantine": "Levantine Arabic - لهجة بلاد الشام",
+    "Maghrebi": "Maghrebi Arabic - اللهجة المغاربية"
+}
+
+def predict_dialect(audio):
     if audio is None:
         return {"Error": 1.0}
 
-
     sr, audio_array = audio
 
-
     if len(audio_array.shape) > 1:
         audio_array = audio_array.mean(axis=1)
 
-
     if audio_array.dtype != np.float32:
-
         if audio_array.dtype == np.int16:
             audio_array = audio_array.astype(np.float32) / 32768.0
         else:
+            audio_array = audio_array.astype(np.float32)
 
     print(f"Processing audio: sample rate={sr}, shape={audio_array.shape}")
 
-
     predictions = predict(sr=sr, audio_segment=audio_array)
 
-
     results = {}
     for pred in predictions:
         dialect_name = dialect_mapping.get(pred['label'], pred['label'])
+        results[dialect_name] = float(pred['score'])
 
     return results
 
@@ -54,14 +58,13 @@ if os.path.exists(examples_dir):
     for filename in os.listdir(examples_dir):
         if filename.endswith((".wav", ".mp3", ".ogg")):
             examples.append([os.path.join(examples_dir, filename)])
-
     print(f"Found {len(examples)} example files")
 else:
     print("Examples directory not found")
 
 # clean description without problematic HTML
 description = """
-
+Developed with ❤️🤍💚 by <a href="https://badrex.github.io/">Badr al-Absi</a>
 
 This is a demo for the accurate and robust Transformer-based <a href="https://huggingface.co/badrex/mms-300m-arabic-dialect-identifier">model</a> for Spoken Arabic Dialect Identification (ADI).
 From just a short audio clip (5-10 seconds), the model can identify Modern Standard Arabic (MSA) as well as four major regional Arabic varieties: Egyptian Arabic, Peninsular Arabic (Gulf, Yemeni, and Iraqi), Levantine Arabic, and Maghrebi Arabic.
@@ -69,7 +72,6 @@ From just a short audio clip (5-10 seconds), the model can identify Modern Stand
 Simply **upload an audio file** 📤 or **record yourself speaking** 🎙️⏺️ to try out the model!
 """
 
-
 demo = gr.Interface(
     fn=predict_dialect,
     inputs=gr.Audio(),
@@ -78,9 +80,7 @@ demo = gr.Interface(
     description=description,
     examples=examples if examples else None,
     cache_examples=False,
-
     flagging_mode=None
 )
 
-
 demo.launch(share=True)
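For quick local verification of the inference path this commit completes, here is a minimal sketch; it is not part of the commit. It assumes the model id from the linked model card (badrex/mms-300m-arabic-dialect-identifier), runs on CPU rather than the Space's ZeroGPU, and substitutes a synthetic int16 clip for the (sample_rate, array) tuple that gr.Audio() passes to predict_dialect; swap in a real recording for a meaningful prediction.

# Local test sketch (assumptions: model id taken from the model card link
# in the description; gr.Audio() delivers a (sample_rate, np.ndarray) tuple).
import numpy as np
from transformers import pipeline

classifier = pipeline(
    "audio-classification",
    model="badrex/mms-300m-arabic-dialect-identifier",  # assumed from the linked model card
)

def predict(audio_segment, sr=16000):
    # Same call shape as the Space's predict(): raw waveform plus sampling rate
    return classifier({"sampling_rate": sr, "raw": audio_segment})

# Synthetic 5-second int16 "clip" standing in for what gr.Audio() provides;
# real speech is needed for the dialect scores to mean anything.
sr = 16000
audio_array = (np.random.randn(sr * 5) * 1000).astype(np.int16)

# The normalization path the commit finishes: int16 PCM -> float32 in [-1, 1)
audio_array = audio_array.astype(np.float32) / 32768.0

for pred in predict(sr=sr, audio_segment=audio_array):
    print(f"{pred['label']}: {pred['score']:.3f}")

The division by 32768.0 maps 16-bit PCM samples (range -32768..32767) onto the [-1.0, 1.0) float range that audio feature extractors expect, which is also why the commit's added else branch casts any other dtype to float32 before inference.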