|
import gradio as gr |
|
import torch |
|
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, pipeline |
|
from transformers import BertTokenizer, BertForSequenceClassification |
|
import librosa |
|
|
|
|
|
|
|
# --- Model loading -----------------------------------------------------------
# All models are loaded once at import time and shared by the request handlers.

# Speech-to-text: a wav2vec2 CTC checkpoint published for Moroccan Darija.
# The processor (feature extractor + tokenizer) and the model share one ID.
ASR_CHECKPOINT = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(ASR_CHECKPOINT)
transcription_model = Wav2Vec2ForCTC.from_pretrained(ASR_CHECKPOINT)

# Abstractive summarization of the transcript.
# NOTE(review): facebook/bart-large-cnn was trained on English news; output
# quality on Darija transcripts is unverified — confirm or swap the model.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# Topic classification of the transcript.
# NOTE(review): "bert-base-uncased" is a general pretrained checkpoint, not a
# fine-tuned topic classifier; its sequence-classification head is presumably
# freshly initialized, so predicted topics are not meaningful until a
# fine-tuned checkpoint is substituted here — confirm.
TOPIC_CHECKPOINT = "bert-base-uncased"
topic_model = BertForSequenceClassification.from_pretrained(TOPIC_CHECKPOINT)
topic_tokenizer = BertTokenizer.from_pretrained(TOPIC_CHECKPOINT)
|
|
|
|
|
|
|
def resample_audio(audio_path, target_sr=16000):
    """Load an audio file and bring it to ``target_sr`` Hz.

    The file is first read at its native sampling rate (``sr=None``) and is
    only passed through ``librosa.resample`` when that rate differs from the
    requested one.

    Args:
        audio_path: Path to the audio file to load.
        target_sr: Desired sampling rate in Hz (defaults to 16000).

    Returns:
        A ``(waveform, target_sr)`` tuple; the waveform is at ``target_sr``.
    """
    waveform, native_sr = librosa.load(audio_path, sr=None)

    # Already at the requested rate: nothing to convert.
    if native_sr == target_sr:
        return waveform, target_sr

    resampled = librosa.resample(waveform, orig_sr=native_sr, target_sr=target_sr)
    return resampled, target_sr
|
|
|
|
|
def transcribe_audio(audio_path):
    """Transcribe one audio file with the Darija wav2vec2 CTC model.

    The audio is loaded and resampled via ``resample_audio`` (16 kHz by
    default), turned into model inputs by ``processor``, and decoded greedily
    (argmax over the vocabulary at every frame).

    Args:
        audio_path: Path to the audio file.

    Returns:
        The decoded transcription string for the file.
    """
    waveform, rate = resample_audio(audio_path)
    features = processor(waveform, sampling_rate=rate, return_tensors="pt", padding=True)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        model_output = transcription_model(**features)

    # Greedy CTC decoding: pick the best token per frame, then let the
    # processor collapse repeats/blanks into text.
    token_ids = model_output.logits.argmax(dim=-1)
    decoded_batch = processor.batch_decode(token_ids)
    # A single file was processed, so the batch holds exactly one string.
    return decoded_batch[0]
|
|
|
|
|
def classify_topic(transcription):
    """Assign a coarse topic label to a transcription.

    The text is tokenized (truncated to 512 tokens) and fed to ``topic_model``.
    The arg-max class index is mapped to a human-readable label; class 0 is
    "Customer Service", class 1 is "Retention Service", and any other index
    is reported as "Other".

    NOTE(review): the topic model is loaded from a generic BERT checkpoint.
    Unless a fine-tuned classifier is substituted, these labels are presumably
    not backed by training — confirm.

    Args:
        transcription: Text to classify.

    Returns:
        One of "Customer Service", "Retention Service" or "Other".
    """
    label_for_index = {
        0: "Customer Service",
        1: "Retention Service",
    }

    encoded = topic_tokenizer(transcription, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        class_scores = topic_model(**encoded).logits

    best_index = class_scores.argmax(dim=1).item()
    # Indices outside the known mapping fall back to the catch-all label.
    return label_for_index.get(best_index, "Other")
|
|
|
|
|
def transcribe_and_summarize(audio_file):
    """Run the full pipeline on one uploaded recording.

    Steps: transcribe the audio, summarize the transcript and classify its
    topic.

    Args:
        audio_file: Path to the audio file from the Gradio ``Audio`` component
            (``type="filepath"``), or ``None`` when the form is submitted
            without a file.

    Returns:
        A ``(transcription, summary, topic)`` tuple of strings, in the same
        order as the interface's three output textboxes. ``summary`` is an
        empty string when nothing was transcribed.

    Raises:
        gr.Error: If no audio file was provided (shown to the user in the UI).
    """
    # Gradio hands us None when nothing was uploaded; report that clearly
    # instead of letting the audio loader fail with a traceback.
    if audio_file is None:
        raise gr.Error("Please upload an audio file.")

    transcription = transcribe_audio(audio_file)

    if transcription.strip():
        # truncation=True keeps long transcripts within the summarizer's
        # maximum input length; without it the pipeline does not truncate
        # and over-long inputs fail during generation.
        summary = summarizer(
            transcription,
            max_length=50,
            min_length=10,
            do_sample=False,
            truncation=True,
        )[0]["summary_text"]
    else:
        # Nothing was recognized (e.g. silent audio): a summary of an empty
        # string would not be meaningful, so leave it blank.
        summary = ""

    topic = classify_topic(transcription)

    return transcription, summary, topic
|
|
|
|
|
# --- Gradio interface --------------------------------------------------------

# type="filepath" makes Gradio pass the handler a path to the uploaded file.
inputs = gr.Audio(type="filepath", label="Upload your audio file")

# One textbox per element of the (transcription, summary, topic) tuple
# returned by transcribe_and_summarize, in the same order.
outputs = [
    gr.Textbox(label=field_label)
    for field_label in ("Transcription", "Summary", "Topic")
]

APP_TITLE = "Moroccan Darija Audio Transcription, Summarization, and Topic Classification (JABRI)"
APP_DESCRIPTION = "Upload an audio file in Moroccan Darija to get its transcription, a summarized version of the content, and the detected topic."

app = gr.Interface(
    fn=transcribe_and_summarize,
    inputs=inputs,
    outputs=outputs,
    title=APP_TITLE,
    description=APP_DESCRIPTION,
)
|
|
|
|
|
def main():
    """Start the Gradio web server for ``app``."""
    app.launch()


if __name__ == "__main__":
    main()
|
|