"""Gradio app: transcribe Moroccan Darija audio, summarize it, and classify its topic."""

import gradio as gr
import torch
import librosa
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, pipeline
from transformers import BertTokenizer, BertForSequenceClassification

# --- Model loading (once, at import time) ---

# Speech-to-text model fine-tuned for Moroccan Darija.
processor = Wav2Vec2Processor.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
transcription_model = Wav2Vec2ForCTC.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")

# Abstractive summarization model.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# Topic classifier. NOTE(review): this loads the *base* BERT checkpoint, so the
# sequence-classification head is randomly initialized — predictions are
# effectively arbitrary until this is replaced by a checkpoint actually
# fine-tuned for the Customer Service / Retention Service task. TODO: confirm
# and swap in a trained model.
topic_model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
topic_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


def resample_audio(audio_path, target_sr=16000):
    """Load an audio file and resample it to ``target_sr`` Hz if necessary.

    Wav2Vec2 XLSR models expect 16 kHz input, so the default matches that.

    Args:
        audio_path: Path to the audio file on disk.
        target_sr: Desired sampling rate in Hz (default 16000).

    Returns:
        Tuple of (audio samples as a 1-D array, target_sr).
    """
    # sr=None keeps the file's native sampling rate so we can detect a mismatch.
    audio_input, original_sr = librosa.load(audio_path, sr=None)
    if original_sr != target_sr:
        audio_input = librosa.resample(audio_input, orig_sr=original_sr, target_sr=target_sr)
    return audio_input, target_sr


def transcribe_audio(audio_path):
    """Transcribe a Darija audio file to text using Wav2Vec2 with greedy CTC decoding.

    Args:
        audio_path: Path to the audio file to transcribe.

    Returns:
        The decoded transcription string.
    """
    audio_input, sample_rate = resample_audio(audio_path)
    inputs = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt", padding=True)
    with torch.no_grad():  # inference only — skip gradient bookkeeping
        logits = transcription_model(**inputs).logits
    # Greedy decoding: pick the most likely token at each frame, then collapse via CTC.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]


def classify_topic(transcription):
    """Classify a transcription into a coarse topic label.

    Label mapping: class 0 -> "Customer Service", class 1 -> "Retention Service",
    any other class -> "Other".

    Args:
        transcription: Text to classify.

    Returns:
        One of "Customer Service", "Retention Service", or "Other".
    """
    # Truncate to BERT's 512-token limit to avoid errors on long transcripts.
    inputs = topic_tokenizer(transcription, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = topic_model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    if predicted_class == 0:
        return "Customer Service"
    if predicted_class == 1:
        return "Retention Service"
    return "Other"


def transcribe_and_summarize(audio_file):
    """Gradio callback: transcribe, summarize, and classify an uploaded audio file.

    Args:
        audio_file: Filesystem path supplied by the Gradio Audio component,
            or None if the user submitted without uploading anything.

    Returns:
        Tuple of (transcription, summary, topic) strings.
    """
    # Gradio passes None when no file was uploaded; fail gracefully instead of crashing.
    if audio_file is None:
        return "No audio file provided.", "", ""
    transcription = transcribe_audio(audio_file)
    summary = summarizer(transcription, max_length=50, min_length=10, do_sample=False)[0]["summary_text"]
    topic = classify_topic(transcription)
    return transcription, summary, topic


# --- Gradio interface ---
inputs = gr.Audio(type="filepath", label="Upload your audio file")
outputs = [
    gr.Textbox(label="Transcription"),
    gr.Textbox(label="Summary"),
    gr.Textbox(label="Topic"),
]

app = gr.Interface(
    fn=transcribe_and_summarize,
    inputs=inputs,
    outputs=outputs,
    title="Moroccan Darija Audio Transcription, Summarization, and Topic Classification (JABRI)",
    description="Upload an audio file in Moroccan Darija to get its transcription, a summarized version of the content, and the detected topic.",
)

# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    app.launch()