zaidmehdi commited on
Commit
2d9e152
·
1 Parent(s): 46e333b

display classification labels

Browse files
Files changed (1) hide show
  1. src/main.py +6 -13
src/main.py CHANGED
@@ -37,28 +37,21 @@ language_model = AutoModel.from_pretrained(model_name)
37
  def classify_arabic_dialect(text):
38
  text_embeddings = extract_hidden_state(text, tokenizer, language_model)
39
  probabilities = model.predict_proba(text_embeddings)[0]
40
- top_three_indices = np.argsort(-probabilities)[:3]
 
41
 
42
- top_three_labels = model.classes_[top_three_indices]
43
- top_three_probabilities = probabilities[top_three_indices]
44
-
45
- return (top_three_labels[0], top_three_probabilities[0]),\
46
- (top_three_labels[1], top_three_probabilities[1]),\
47
- (top_three_labels[2], top_three_probabilities[2])
48
 
49
 
50
  with gr.Blocks() as demo:
51
  gr.HTML(index_html)
52
  input_text = gr.Textbox(label="Your Arabic Text")
53
  submit_btn = gr.Button("Submit")
54
- with gr.Row():
55
- first_country = gr.Textbox()
56
- second_country = gr.Textbox()
57
- third_country = gr.Textbox()
58
  submit_btn.click(
59
  fn=classify_arabic_dialect,
60
  inputs=input_text,
61
- outputs=[first_country, second_country, third_country])
62
  gr.HTML("""
63
  <p style="text-align: center;font-size: large;">
64
  Checkout the <a href="https://github.com/zaidmehdi/arabic-dialect-classifier">Github Repo</a>
@@ -67,4 +60,4 @@ with gr.Blocks() as demo:
67
 
68
 
69
  if __name__ == "__main__":
70
- demo.launch()
 
37
  def classify_arabic_dialect(text):
38
  text_embeddings = extract_hidden_state(text, tokenizer, language_model)
39
  probabilities = model.predict_proba(text_embeddings)[0]
40
+ labels = model.classes_
41
+ predictions = {labels[i]: probabilities[i] for i in range(len(probabilities))}
42
 
43
+ return predictions
 
 
 
 
 
44
 
45
 
46
  with gr.Blocks() as demo:
47
  gr.HTML(index_html)
48
  input_text = gr.Textbox(label="Your Arabic Text")
49
  submit_btn = gr.Button("Submit")
50
+ predictions = gr.Label(num_top_classes=3)
 
 
 
51
  submit_btn.click(
52
  fn=classify_arabic_dialect,
53
  inputs=input_text,
54
+ outputs=predictions)
55
  gr.HTML("""
56
  <p style="text-align: center;font-size: large;">
57
  Checkout the <a href="https://github.com/zaidmehdi/arabic-dialect-classifier">Github Repo</a>
 
60
 
61
 
62
  if __name__ == "__main__":
63
+ demo.launch()