File size: 9,234 Bytes
e7d0ead
6cc8631
 
2785b50
 
 
bc32c3a
2785b50
 
 
 
 
 
 
 
 
 
 
 
85e680f
 
2785b50
85e680f
2785b50
f2ecb6e
 
4a98e32
 
f2ecb6e
2785b50
f2ecb6e
 
4a98e32
f2ecb6e
3b0e26c
 
 
8dc5e68
3b0e26c
 
 
 
9e3ffca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a98e32
b6cc6ac
 
 
 
3f0375d
b6cc6ac
 
 
 
 
 
 
3f0375d
b6cc6ac
 
 
 
 
 
 
 
 
 
 
 
 
4a98e32
b6cc6ac
 
 
4a98e32
 
 
 
b6cc6ac
 
 
 
 
 
 
4a98e32
 
 
 
b6cc6ac
 
980dcf2
2785b50
f117ae1
 
 
2785b50
f117ae1
2785b50
f117ae1
2785b50
f117ae1
 
 
 
 
9078685
4a98e32
f117ae1
 
9078685
4a98e32
b6cc6ac
a19085f
b6cc6ac
4a98e32
3b0e26c
 
 
 
9078685
f117ae1
3b0e26c
d04bf8d
 
2785b50
 
 
 
 
980dcf2
f2ecb6e
2785b50
 
f2ecb6e
 
 
 
 
 
3b0e26c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2785b50
1a38424
f2ecb6e
9078685
 
f2ecb6e
 
9078685
 
4a98e32
 
 
 
3b0e26c
f2ecb6e
 
4a98e32
 
3b0e26c
d27a60f
980dcf2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import logging

import gradio as gr
import librosa
import torch
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertForSequenceClassification,
    MarianMTModel,
    MarianTokenizer,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)

# Detect the device once; every model below is moved here, and the inference
# helpers move their input tensors to this same device.
device = "cuda" if torch.cuda.is_available() else "cpu"

### 🔹 Label sets ###
# Defined before the models so each classification head can be sized from the
# label list its predictions are decoded against.

# Sentiment labels, indexed by the sentiment model's predicted class.
# NOTE(review): this order must match the checkpoint's id2label mapping — confirm.
sentiment_labels = ["Négatif", "Neutre", "Positif"]

# Darija topic labels (Arabic script followed by a Latin-script transliteration).
darija_topic_labels = [
    "مشكيل ف الشبكة (Mochkil f réseau)",        # Network problem
    "مشكيل ف الانترنت (Mochkil f internet)",      # Internet problem
    "مشكيل ف الفاتورة (Mochkil f l'factura)",     # Billing and payment problem
    "مشكيل ف التعبئة (Mochkil f l'recharge)",     # Recharge and plans problem
    "مشكيل ف التطبيق (Mochkil f l'application)",  # Problem with the app (Orange et Moi...)
    "مشكيل ف بطاقة SIM (Mochkil f carte SIM)",    # Problem with the SIM card
    "مساعدة تقنية (Mosa3ada technique)",         # Technical assistance
    "العروض والتخفيضات (Offres w promotions)",   # Offers and promotions
    "طلب معلومات (Talab l'ma3loumat)",           # Information request
    "شكاية (Chikaya)",                            # Complaint
    "حاجة أخرى (Chi haja okhra)"                 # Other
]

# English topic labels, index-aligned with darija_topic_labels.
english_topic_labels = [
    "Network Issue",
    "Internet Issue",
    "Billing & Payment Issue",
    "Recharge & Plans",
    "App Issue",
    "SIM Card Issue",
    "Technical Support",
    "Offers & Promotions",
    "General Inquiry",
    "Complaint",
    "Other"
]

### 🔹 Load Models & Tokenizers Once ###
# Wav2Vec2 for Darija transcription
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)

# MarianMT for translation (Arabic → English)
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)

# AraBERT for Darija topic classification.
# The head is sized from the label list so every Darija topic is reachable;
# a smaller head would make the trailing labels impossible to predict.
# NOTE(review): this looks like a base (not fine-tuned) checkpoint, in which
# case the classification head is freshly initialised and its predictions are
# not meaningful until the model is fine-tuned — confirm.
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(
    arabert_model_name, num_labels=len(darija_topic_labels)
).to(device)

# BERT for English topic classification (same head-sizing rule and the same
# NOTE(review) about an untrained head as the AraBERT model above).
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
    bert_model_name, num_labels=len(english_topic_labels)
).to(device)

# Darija sentiment model and tokenizer.
# ignore_mismatched_sizes=True lets loading succeed even if the checkpoint's
# head size differs from num_labels (transformers then re-initialises the head).
sentiment_model_name = "BenhamdaneNawfal/sentiment-analysis-darija"
sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    sentiment_model_name,
    num_labels=len(sentiment_labels),
    ignore_mismatched_sizes=True,
).to(device)

# Keyword tables for classify_topic_by_keywords, built once at import time
# rather than on every call. Keys are the topic names returned to callers.
# Insertion order matters: on a score tie the earliest topic wins.
_ARABIC_TOPIC_KEYWORDS = {
    "Customer Service": ["خدمة", "استفسار", "مساعدة", "دعم", "سؤال"],
    "résiliation Service": ["نوقف", "تجديد", "خصم", "عرض", "نحي"],
    "Billing Issue": ["فاتورة", "دفع", "مشكلة", "خطأ", "مبلغ"],
    "Other": ["شيء آخر", "غير ذلك", "أخرى"],
}

# NOTE(review): matching is by plain substring, so e.g. "cut" also matches
# "execute", "other" matches "another", and "stopped" scores twice (it also
# contains "stop"). Consider word-boundary matching for English.
_ENGLISH_TOPIC_KEYWORDS = {
    "Customer Service": ["service", "inquiry", "help", "support", "question", "assistance"],
    "résiliation Service": ["retain", "cut", "discount", "stopped", "promotion", "stop"],
    "Billing Issue": ["bill", "payment", "problem", "error", "amount"],
    "Other": ["other", "none of the above", "something else"],
}


def classify_topic_by_keywords(text, language='ar'):
    """Classify *text* into a coarse topic by counting keyword occurrences.

    Each keyword of a topic that appears (as a substring) in the lower-cased
    text adds one point to that topic.

    Args:
        text: The text to classify. ``None`` is treated as empty text.
        language: ``'ar'`` to use the Arabic keyword table, ``'en'`` for English.

    Returns:
        The highest-scoring topic name; ties go to the topic listed first.
        Returns ``"Other"`` when no keyword matches at all.

    Raises:
        ValueError: If *language* is neither ``'ar'`` nor ``'en'``.
    """
    if language == 'ar':
        keywords = _ARABIC_TOPIC_KEYWORDS
    elif language == 'en':
        keywords = _ENGLISH_TOPIC_KEYWORDS
    else:
        raise ValueError("Invalid language specified. Use 'ar' for Arabic or 'en' for English.")

    # Lower-case so English matching is case-insensitive (a no-op for Arabic script).
    text = (text or "").lower()

    # One point per distinct keyword of the topic found in the text.
    topic_scores = {
        topic: sum(1 for word in words if word in text)
        for topic, words in keywords.items()
    }

    if not any(topic_scores.values()):
        return "Other"

    # max() keeps the first key among equal scores, i.e. dict insertion order.
    return max(topic_scores, key=topic_scores.get)




def transcribe_audio(audio):
    """Convert audio to text, translate it, and classify topics in both Darija and English.

    Args:
        audio: Path to the audio file (the Gradio input is ``type="filepath"``),
            or ``None`` when the user submits without uploading or recording.

    Returns:
        A 7-tuple of strings, in the order the UI outputs expect:
        (transcription, translation, darija_topic, english_topic,
        darija_keyword_topic, english_keyword_topic, sentiment).
        On failure the first field carries an error message and the other
        six are empty strings.
    """
    # Gradio passes None when nothing was provided; fail fast with a clear
    # message instead of surfacing an opaque loader error.
    if audio is None:
        return "Error processing audio: no audio provided", "", "", "", "", "", ""

    try:
        # Load and resample to 16 kHz, the rate handed to the processor below.
        audio_array, _ = librosa.load(audio, sr=16000)
        input_values = processor(
            audio_array, sampling_rate=16000, return_tensors="pt", padding=True
        ).input_values.to(device)

        # Transcription (Darija): greedy CTC decoding, i.e. the most likely
        # token per frame, then decoded to text by the processor.
        with torch.no_grad():
            logits = wav2vec_model(input_values).logits
        tokens = torch.argmax(logits, dim=-1)
        transcription = processor.decode(tokens[0])

        # Translate to English
        translation = translate_text(transcription)

        # Classify topics using the BERT-based models
        darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
        english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)

        # Classify topics using the keyword-based classifier
        darija_keyword_topic = classify_topic_by_keywords(transcription, language='ar')
        english_keyword_topic = classify_topic_by_keywords(translation, language='en')

        # Sentiment analysis on the Darija transcription
        sentiment = analyze_sentiment(transcription)

        return (
            transcription,
            translation,
            darija_topic,
            english_topic,
            darija_keyword_topic,
            english_keyword_topic,
            sentiment,
        )

    except Exception as e:
        # UI boundary: keep the full traceback in the logs for the operator,
        # and show a short message in the transcription box for the user.
        logging.getLogger(__name__).exception("Audio processing failed")
        return f"Error processing audio: {str(e)}", "", "", "", "", "", ""

def translate_text(text):
    """Return the English MarianMT translation of an Arabic/Darija string.

    The input is truncated to the tokenizer's maximum length. Only the first
    generated sequence is decoded, with special tokens stripped.
    """
    # Encode and move the tensors to the device the models live on.
    encoded = translation_tokenizer(
        text,
        truncation=True,
        padding=True,
        return_tensors="pt",
    ).to(device)

    # Inference only: no gradient bookkeeping needed.
    with torch.no_grad():
        generated = translation_model.generate(**encoded)

    first_sequence = generated[0]
    return translation_tokenizer.decode(first_sequence, skip_special_tokens=True)

def classify_topic(text, tokenizer, model, topic_labels):
    """Predict a topic for *text* with a sequence-classification model.

    Args:
        text: Input text (truncated to 512 tokens).
        tokenizer: Tokenizer matching *model*.
        model: Sequence-classification model whose output exposes ``.logits``.
        topic_labels: Labels indexed by the model's predicted class.

    Returns:
        ``topic_labels[i]`` for the arg-max class ``i``, or ``"Other"`` when
        ``i`` falls outside *topic_labels*.
    """
    encoded = tokenizer(
        text,
        max_length=512,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        scores = model(**encoded).logits
        best_index = scores.argmax(dim=1).item()

    # Guard against a head with more classes than there are labels.
    if best_index >= len(topic_labels):
        return "Other"
    return topic_labels[best_index]

def analyze_sentiment(text):
    """Classify the sentiment of a Darija text.

    Args:
        text: Darija text (truncated to 512 tokens).

    Returns:
        One of ``sentiment_labels`` for the arg-max class, or ``"Inconnu"``
        ("unknown") when the predicted index has no matching label.
    """
    # Use the module-level ``device`` (the same value the models were loaded
    # with) rather than re-detecting it here, so there is a single source of truth.
    inputs = sentiment_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=512
    ).to(device)

    # Prediction (inference only, no gradients)
    with torch.no_grad():
        outputs = sentiment_model(**inputs)
        predicted_class = torch.argmax(outputs.logits, dim=1).item()

    # Map the class index to its label, falling back when out of range.
    return sentiment_labels[predicted_class] if predicted_class < len(sentiment_labels) else "Inconnu"


# 🔹 Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")

    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification (BERT)")
    english_topic_output = gr.Textbox(label="English Topic Classification (BERT)")
    darija_keyword_topic_output = gr.Textbox(label="Darija Topic Classification (Keywords)")
    english_keyword_topic_output = gr.Textbox(label="English Topic Classification (Keywords)")
    sentiment_output = gr.Textbox(label="Sentiment (Darija)")

    # The outputs list must stay in the same order as the 7-tuple
    # returned by transcribe_audio.
    submit_button.click(
        transcribe_audio,
        inputs=[audio_input],
        outputs=[
            transcription_output,
            translation_output,
            darija_topic_output,
            english_topic_output,
            darija_keyword_topic_output,
            english_keyword_topic_output,
            sentiment_output,
        ],
    )

# Start the server only when run as a script, so importing this module
# (tests, tooling, reload mode) does not block on a running web server.
if __name__ == "__main__":
    demo.launch()