import gradio as gr
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, pipeline
from transformers import BertTokenizer, BertForSequenceClassification
import librosa
# Load models
# Transcription model for Moroccan Darija
processor = Wav2Vec2Processor.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
transcription_model = Wav2Vec2ForCTC.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
# Summarization model
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
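# Note: bart-large-cnn is an English-only summarizer, so summaries of Darija
# (Arabic-script) transcriptions may be unreliable; a multilingual model such
# as "csebuetnlp/mT5_multilingual_XLSum" could be swapped in here instead.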
# Topic classification model (generic BERT used as a placeholder)
topic_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
topic_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
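# Note: the classification head above is randomly initialized (num_labels=3 is
# an assumption matching the label mapping in classify_topic below), so the
# predicted topics are placeholders until the model is fine-tuned on labeled
# call transcripts.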
# Function to resample audio to 16kHz if necessary
def resample_audio(audio_path, target_sr=16000):
    audio_input, original_sr = librosa.load(audio_path, sr=None)  # Load audio (mono) at its original sampling rate
    if original_sr != target_sr:
        audio_input = librosa.resample(audio_input, orig_sr=original_sr, target_sr=target_sr)  # Resample to 16 kHz
    return audio_input, target_sr
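# Usage sketch (hypothetical file name):
#   waveform, sr = resample_audio("call.wav")  # -> (np.ndarray, 16000)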
# Function to transcribe audio using Wav2Vec2
def transcribe_audio(audio_path):
    # Load and preprocess audio
    audio_input, sample_rate = resample_audio(audio_path)
    inputs = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt", padding=True)
    # Get predictions
    with torch.no_grad():
        logits = transcription_model(**inputs).logits
    # Decode predictions (greedy CTC decoding)
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription
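# Note: the whole clip is decoded in a single forward pass; for long
# recordings, splitting the waveform into ~30-second chunks before inference
# would keep memory use bounded.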
# Function to classify the transcription into topics
def classify_topic(transcription):
    # Tokenize the transcription and pass it through the BERT classifier
    inputs = topic_tokenizer(transcription, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = topic_model(**inputs)
    # Get the predicted class index (0 = Customer Service, 1 = Retention Service, 2 = Other)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    # Map the prediction to a topic label
    if predicted_class == 0:
        return "Customer Service"
    elif predicted_class == 1:
        return "Retention Service"
    else:
        return "Other"
# Function to transcribe, summarize, and classify topic
def transcribe_and_summarize(audio_file):
    # Transcription
    transcription = transcribe_audio(audio_file)
    # Summarization
    summary = summarizer(transcription, max_length=50, min_length=10, do_sample=False)[0]["summary_text"]
    # Topic classification
    topic = classify_topic(transcription)
    return transcription, summary, topic
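# Note: bart-large-cnn can only attend to roughly 1024 tokens, so very long
# transcriptions get truncated. A minimal chunked-summarization sketch
# (hypothetical helper, not wired into the interface) could look like this:
def summarize_long(text, chunk_chars=2000):
    # Split into fixed-size character windows, summarize each, then join
    chunks = [text[i:i + chunk_chars] for i in range(0, len(text), chunk_chars)]
    parts = [
        summarizer(chunk, max_length=50, min_length=10, do_sample=False)[0]["summary_text"]
        for chunk in chunks
    ]
    return " ".join(parts)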
# Gradio Interface
inputs = gr.Audio(type="filepath", label="Upload your audio file")
outputs = [
    gr.Textbox(label="Transcription"),
    gr.Textbox(label="Summary"),
    gr.Textbox(label="Topic"),
]
app = gr.Interface(
    fn=transcribe_and_summarize,
    inputs=inputs,
    outputs=outputs,
    title="Moroccan Darija Audio Transcription, Summarization, and Topic Classification (JABRI)",
    description="Upload an audio file in Moroccan Darija to get its transcription, a summarized version of the content, and the detected topic.",
)
# Launch the app
if __name__ == "__main__":
    app.launch()