File size: 3,109 Bytes
d46aa73
 
 
 
b78ce3a
17f8795
07a50af
b78ce3a
 
 
 
 
 
 
 
 
 
a6158c1
099b786
07a50af
d06382d
b78ce3a
d46aa73
 
 
 
 
 
87966ec
b04a244
87966ec
 
 
 
099b786
fa9bb6e
 
 
 
 
d46aa73
87966ec
 
 
b78ce3a
87966ec
 
 
 
d46aa73
87966ec
 
23545c8
099b786
87966ec
23545c8
 
 
 
87966ec
 
b78ce3a
 
23545c8
b78ce3a
099b786
d46aa73
406c152
b78ce3a
 
 
86ceded
 
 
 
 
 
 
e603ca9
86ceded
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr
from transformers import pipeline
import os
import numpy as np
import torch
import spaces

# load the model
# Build a transformers audio-classification pipeline for Arabic dialect ID,
# requesting CUDA placement at construction time.
# NOTE(review): the second print below claims a separate "moved to GPU" step,
# but no such move happens in this code — device='cuda' is the only placement.
print("Loading model...")
model_id = "badrex/mms-300m-arabic-dialect-identifier"
classifier = pipeline("audio-classification", model=model_id, device='cuda')
print("Model loaded successfully")
print("Model moved to GPU successfully")

@spaces.GPU
def predict(audio_segment, sr=16000):
    """Run the dialect classifier on a raw waveform.

    Args:
        audio_segment: 1-D float32 waveform samples.
        sr: sampling rate of the waveform in Hz (default 16 kHz).

    Returns:
        The pipeline's list of {'label', 'score'} predictions.
    """
    payload = {"sampling_rate": sr, "raw": audio_segment}
    return classifier(payload)

# define dialect mapping
# Maps the model's raw label strings to human-readable names
# (English plus Arabic), shown in the gr.Label output.
dialect_mapping = {
    "MSA": "Modern Standard Arabic (MSA) - العربية الفصحى الحديثة",
    "Egyptian": "Egyptian Arabic -  اللهجة المصرية العامية",
    "Gulf": "Peninsular Arabic - لهجة الجزيرة العربية",
    "Levantine": "Levantine Arabic - لهجة بلاد الشام",
    "Maghrebi": "Maghrebi Arabic - اللهجة المغاربية"
}

def predict_dialect(audio):
    """Classify the Arabic dialect spoken in an audio clip.

    Args:
        audio: tuple (sample_rate, np.ndarray) as produced by gr.Audio,
            or None when no audio was provided.

    Returns:
        Dict mapping human-readable dialect name -> probability, suitable
        for gr.Label; {"Error": 1.0} when no audio is supplied.
    """
    if audio is None:
        return {"Error": 1.0}

    sr, audio_array = audio

    # Down-mix multi-channel (stereo) audio to mono.
    if len(audio_array.shape) > 1:
        audio_array = audio_array.mean(axis=1)

    # Convert to float32 in [-1, 1). Previously only int16 was rescaled;
    # any other integer PCM dtype (e.g. int32) was cast raw, feeding the
    # model sample values in the millions. Normalize every integer dtype
    # by its positive range (int16: 32767 + 1 == 32768, matching the old
    # behavior exactly).
    if audio_array.dtype != np.float32:
        if np.issubdtype(audio_array.dtype, np.integer):
            scale = float(np.iinfo(audio_array.dtype).max) + 1.0
            audio_array = audio_array.astype(np.float32) / scale
        else:
            audio_array = audio_array.astype(np.float32)

    print(f"Processing audio: sample rate={sr}, shape={audio_array.shape}")

    predictions = predict(sr=sr, audio_segment=audio_array)

    # Translate model labels into display names; unknown labels pass through.
    results = {}
    for pred in predictions:
        dialect_name = dialect_mapping.get(pred['label'], pred['label'])
        results[dialect_name] = float(pred['score'])

    return results

# prepare example clips for the Gradio UI
# Each entry is a one-element list (single audio input) pointing at a file
# in ./examples; missing directory just means no examples are offered.
examples = []
examples_dir = "examples"
if os.path.exists(examples_dir):
    audio_suffixes = (".wav", ".mp3", ".ogg")
    examples = [
        [os.path.join(examples_dir, fname)]
        for fname in os.listdir(examples_dir)
        if fname.endswith(audio_suffixes)
    ]
    print(f"Found {len(examples)} example files")
else:
    print("Examples directory not found")

# clean description without problematic HTML
# Markdown/HTML blurb rendered under the app title: author credit, model-card
# link, supported dialects, and usage hint (kept to simple anchors only).
description = """
Developed with ❤️🤍💚 by <a href="https://badrex.github.io/">Badr al-Absi</a> 

This is a demo for the accurate and robust Transformer-based <a href="https://huggingface.co/badrex/mms-300m-arabic-dialect-identifier">model</a> for Spoken Arabic Dialect Identification (ADI). 
From just a short audio clip (5-10 seconds), the model can identify Modern Standard Arabic (MSA) as well as four major regional Arabic varieties: Egyptian Arabic, Peninsular Arabic (Gulf, Yemeni, and Iraqi), Levantine Arabic, and Maghrebi Arabic.

Simply **upload an audio file** 📤 or **record yourself speaking** 🎙️⏺️ to try out the model!
"""

# Wire up the Gradio UI: one audio input -> top-5 dialect label output.
demo = gr.Interface(
    fn=predict_dialect,
    inputs=gr.Audio(),
    outputs=gr.Label(num_top_classes=5, label="Predicted Dialect"),
    title="<div>Tamyïz 🍉 <br> Arabic Dialect Identification in Speech</div>",
    description=description,
    examples=examples if examples else None,
    cache_examples=False,
    # NOTE(review): Gradio documents flagging_mode values "never"/"auto"/
    # "manual"; None presumably falls back to the library default — confirm
    # "never" was the intent.
    flagging_mode=None
)

# share=True asks Gradio for a public tunnel link; on hosted platforms
# (e.g. HF Spaces) this is typically ignored — TODO confirm deployment target.
demo.launch(share=True)