Mohssinibra committed on
Commit
6349c25
·
verified ·
1 Parent(s): 9e3ffca
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  from transformers import (
5
  Wav2Vec2ForCTC, Wav2Vec2Processor,
6
  MarianMTModel, MarianTokenizer,
7
- BertForSequenceClassification, AutoModel, AutoTokenizer
8
  )
9
 
10
  # Detect device
@@ -24,13 +24,12 @@ translation_model = MarianMTModel.from_pretrained(translation_model_name).to(dev
24
  # AraBERT for Darija topic classification
25
  arabert_model_name = "aubmindlab/bert-base-arabert"
26
  arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
27
- arabert_model = BertForSequenceClassification.from_pretrained(arabert_model_name, num_labels=2).to(device)
28
-
29
 
30
  # BERT for English topic classification
31
  bert_model_name = "bert-base-uncased"
32
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
33
- bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=3).to(device)
34
 
35
  # Libellés en Darija (Arabe et Latin)
36
  darija_topic_labels = [
@@ -62,7 +61,6 @@ english_topic_labels = [
62
  "Other"
63
  ]
64
 
65
-
66
  def transcribe_audio(audio):
67
  """Convert audio to text, translate it, and classify topics in both Darija and English."""
68
  try:
 
4
  from transformers import (
5
  Wav2Vec2ForCTC, Wav2Vec2Processor,
6
  MarianMTModel, MarianTokenizer,
7
+ BertForSequenceClassification, AutoTokenizer, AutoModel
8
  )
9
 
10
  # Detect device
 
24
  # AraBERT for Darija topic classification
25
  arabert_model_name = "aubmindlab/bert-base-arabert"
26
  arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
27
+ arabert_model = BertForSequenceClassification.from_pretrained(arabert_model_name, num_labels=11).to(device) # Adjusted to 11 labels for Darija
 
28
 
29
  # BERT for English topic classification
30
  bert_model_name = "bert-base-uncased"
31
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
32
+ bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=11).to(device) # Adjusted to 11 labels for English
33
 
34
  # Libellés en Darija (Arabe et Latin)
35
  darija_topic_labels = [
 
61
  "Other"
62
  ]
63
 
 
64
  def transcribe_audio(audio):
65
  """Convert audio to text, translate it, and classify topics in both Darija and English."""
66
  try: