Mohssinibra committed on
Commit f117ae1 · verified · 1 Parent(s): f2ecb6e
Files changed (1)
  app.py +31 -10
app.py CHANGED
@@ -28,22 +28,43 @@ darija_topic_labels = ["Customer Service", "Retention Service", "Billing Issue"]
 english_topic_labels = ["Support Request", "Subscription Issue", "Payment Dispute"]  # Adjust for English topics


+import torch
+
 def transcribe_audio(audio):
     """Convert audio to text, translate it, and classify topics in both Darija and English"""
-    audio_array, sr = librosa.load(audio, sr=16000)
-    input_values = processor(audio_array, return_tensors="pt", padding=True).input_values
+    try:
+        # Load and preprocess audio
+        audio_array, sr = librosa.load(audio, sr=16000)
+
+        # Ensure correct sampling rate
+        input_values = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True).input_values
+
+        # Move to GPU if available
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        input_values = input_values.to(device)
+
+        # Get predictions from Wav2Vec2 model
+        with torch.no_grad():
+            logits = model(input_values).logits
+            tokens = torch.argmax(logits, axis=-1)
+
+        # Decode transcription (Darija)
+        transcription = processor.decode(tokens[0])
+
+        # Translate to English
+        translation = translate_text(transcription)

-    logits = model(input_values).logits
-    tokens = torch.argmax(logits, axis=-1)
+        # Classify topics for Darija and English
+        darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
+        english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)

-    transcription = processor.decode(tokens[0])
-    translation = translate_text(transcription)
+        return transcription, translation, darija_topic, english_topic

-    # Classify topics for both Darija and English
-    darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
-    english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)
+    except Exception as e:
+        print(f"Error in transcription: {e}")
+        return "Error processing audio", "", "", ""

-    return transcription, translation, darija_topic, english_topic


 def translate_text(text):
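For context, below is a minimal, self-contained sketch of the inference pattern this commit introduces: resample to 16 kHz, pass sampling_rate to the processor, keep model and inputs on the same device, run the forward pass under torch.no_grad(), and greedy-decode the CTC output. It is a sketch under assumptions, not the Space's actual code: the checkpoint name and the "sample.wav" path are placeholders, and the real app.py also defines the translation and topic-classification models that are omitted here.

# Sketch of the updated transcription flow; MODEL_ID is a placeholder,
# not the Darija checkpoint used by the Space.
import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

MODEL_ID = "facebook/wav2vec2-base-960h"  # placeholder checkpoint
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

def transcribe(audio_path: str) -> str:
    try:
        # Resample to the 16 kHz rate the model expects
        audio_array, sr = librosa.load(audio_path, sr=16000)
        inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)

        # Keep model and inputs on the same device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        input_values = inputs.input_values.to(device)

        # Forward pass without gradient tracking, then greedy CTC decoding
        with torch.no_grad():
            logits = model(input_values).logits
        tokens = torch.argmax(logits, dim=-1)
        return processor.decode(tokens[0])
    except Exception as e:
        print(f"Error in transcription: {e}")
        return "Error processing audio"

if __name__ == "__main__":
    print(transcribe("sample.wav"))  # hypothetical audio file

Compared to the pre-commit version, the practical differences are that the processor is told the sampling rate explicitly, the forward pass no longer builds a gradient graph (lower memory use during inference), and failures return a fallback value instead of raising into the UI.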