badrex commited on
Commit
6f55dac
·
verified ·
1 Parent(s): 099b786

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -21
app.py CHANGED
@@ -2,30 +2,27 @@ import gradio as gr
2
  from transformers import pipeline
3
  import os
4
  import numpy as np
5
- import torch
6
  import spaces
7
 
8
- # load the model
9
- print("Loading model...")
10
- model_id = "badrex/mms-300m-arabic-dialect-identifier"
11
- classifier = pipeline("audio-classification", model=model_id, device='cuda')
12
- print("Model loaded successfully")
13
- print("Model moved to GPU successfully")
14
-
15
- @spaces.GPU
16
- def predict(audio_segment, sr=16000):
17
- return classifier({"sampling_rate": sr, "raw": audio_segment})
18
 
19
  # define dialect mapping
20
  dialect_mapping = {
21
  "MSA": "Modern Standard Arabic (MSA) - العربية الفصحى الحديثة",
22
- "Egyptian": "Egyptian Arabic - اللهجة المصرية العامية",
23
  "Gulf": "Peninsular Arabic - لهجة الجزيرة العربية",
24
  "Levantine": "Levantine Arabic - لهجة بلاد الشام",
25
  "Maghrebi": "Maghrebi Arabic - اللهجة المغاربية"
26
  }
27
 
 
28
  def predict_dialect(audio):
 
 
 
 
 
 
29
  if audio is None:
30
  return {"Error": 1.0}
31
 
@@ -42,8 +39,10 @@ def predict_dialect(audio):
42
 
43
  print(f"Processing audio: sample rate={sr}, shape={audio_array.shape}")
44
 
45
- predictions = predict(sr=sr, audio_segment=audio_array)
 
46
 
 
47
  results = {}
48
  for pred in predictions:
49
  dialect_name = dialect_mapping.get(pred['label'], pred['label'])
@@ -59,17 +58,12 @@ if os.path.exists(examples_dir):
59
  if filename.endswith((".wav", ".mp3", ".ogg")):
60
  examples.append([os.path.join(examples_dir, filename)])
61
  print(f"Found {len(examples)} example files")
62
- else:
63
- print("Examples directory not found")
64
 
65
- # clean description without problematic HTML
66
  description = """
67
  By <a href="https://badrex.github.io/">Badr Alabsi</a> with ❤️🤍💚
68
 
69
- This is a demo for the accurate and robust Transformer-based <a href="https://huggingface.co/badrex/mms-300m-arabic-dialect-identifier">model</a> for Spoken Arabic Dialect Identification (ADI).
70
- From just a short audio clip (5-10 seconds), the model can identify Modern Standard Arabic (MSA) as well as four major regional Arabic varieties: Egyptian Arabic, Peninsular Arabic (Gulf, Yemeni, and Iraqi), Levantine Arabic, and Maghrebi Arabic.
71
-
72
- Simply **upload an audio file** 📤 or **record yourself speaking** 🎙️⏺️ to try out the model!
73
  """
74
 
75
  demo = gr.Interface(
@@ -83,4 +77,5 @@ demo = gr.Interface(
83
  flagging_mode=None
84
  )
85
 
86
- demo.launch(share=True)
 
 
2
  from transformers import pipeline
3
  import os
4
  import numpy as np
 
5
  import spaces
6
 
7
+ print("=== Application Starting ===")
 
 
 
 
 
 
 
 
 
8
 
9
  # define dialect mapping
10
  dialect_mapping = {
11
  "MSA": "Modern Standard Arabic (MSA) - العربية الفصحى الحديثة",
12
+ "Egyptian": "Egyptian Arabic - اللهجة المصرية العامية",
13
  "Gulf": "Peninsular Arabic - لهجة الجزيرة العربية",
14
  "Levantine": "Levantine Arabic - لهجة بلاد الشام",
15
  "Maghrebi": "Maghrebi Arabic - اللهجة المغاربية"
16
  }
17
 
18
+ @spaces.GPU
19
  def predict_dialect(audio):
20
+ # load model inside the GPU function
21
+ print("Loading model on GPU...")
22
+ model_id = "badrex/mms-300m-arabic-dialect-identifier"
23
+ classifier = pipeline("audio-classification", model=model_id) # no device specified
24
+ print("Model loaded successfully")
25
+
26
  if audio is None:
27
  return {"Error": 1.0}
28
 
 
39
 
40
  print(f"Processing audio: sample rate={sr}, shape={audio_array.shape}")
41
 
42
+ # classify the dialect
43
+ predictions = classifier({"sampling_rate": sr, "raw": audio_array})
44
 
45
+ # format results
46
  results = {}
47
  for pred in predictions:
48
  dialect_name = dialect_mapping.get(pred['label'], pred['label'])
 
58
  if filename.endswith((".wav", ".mp3", ".ogg")):
59
  examples.append([os.path.join(examples_dir, filename)])
60
  print(f"Found {len(examples)} example files")
 
 
61
 
 
62
  description = """
63
  By <a href="https://badrex.github.io/">Badr Alabsi</a> with ❤️🤍💚
64
 
65
+ This demo uses a Transformer-based model for Spoken Arabic Dialect Identification.
66
+ Upload an audio file or record yourself speaking to identify the Arabic dialect!
 
 
67
  """
68
 
69
  demo = gr.Interface(
 
77
  flagging_mode=None
78
  )
79
 
80
+ print("=== Launching demo ===")
81
+ demo.launch()