Spaces:
Sleeping
Sleeping
display classification labels
Browse files- src/main.py +6 -13
src/main.py
CHANGED
@@ -37,28 +37,21 @@ language_model = AutoModel.from_pretrained(model_name)
|
|
37 |
def classify_arabic_dialect(text):
|
38 |
text_embeddings = extract_hidden_state(text, tokenizer, language_model)
|
39 |
probabilities = model.predict_proba(text_embeddings)[0]
|
40 |
-
|
|
|
41 |
|
42 |
-
|
43 |
-
top_three_probabilities = probabilities[top_three_indices]
|
44 |
-
|
45 |
-
return (top_three_labels[0], top_three_probabilities[0]),\
|
46 |
-
(top_three_labels[1], top_three_probabilities[1]),\
|
47 |
-
(top_three_labels[2], top_three_probabilities[2])
|
48 |
|
49 |
|
50 |
with gr.Blocks() as demo:
|
51 |
gr.HTML(index_html)
|
52 |
input_text = gr.Textbox(label="Your Arabic Text")
|
53 |
submit_btn = gr.Button("Submit")
|
54 |
-
|
55 |
-
first_country = gr.Textbox()
|
56 |
-
second_country = gr.Textbox()
|
57 |
-
third_country = gr.Textbox()
|
58 |
submit_btn.click(
|
59 |
fn=classify_arabic_dialect,
|
60 |
inputs=input_text,
|
61 |
-
outputs=
|
62 |
gr.HTML("""
|
63 |
<p style="text-align: center;font-size: large;">
|
64 |
Checkout the <a href="https://github.com/zaidmehdi/arabic-dialect-classifier">Github Repo</a>
|
@@ -67,4 +60,4 @@ with gr.Blocks() as demo:
|
|
67 |
|
68 |
|
69 |
if __name__ == "__main__":
|
70 |
-
demo.launch()
|
|
|
37 |
def classify_arabic_dialect(text):
|
38 |
text_embeddings = extract_hidden_state(text, tokenizer, language_model)
|
39 |
probabilities = model.predict_proba(text_embeddings)[0]
|
40 |
+
labels = model.classes_
|
41 |
+
predictions = {labels[i]: probabilities[i] for i in range(len(probabilities))}
|
42 |
|
43 |
+
return predictions
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
|
46 |
with gr.Blocks() as demo:
|
47 |
gr.HTML(index_html)
|
48 |
input_text = gr.Textbox(label="Your Arabic Text")
|
49 |
submit_btn = gr.Button("Submit")
|
50 |
+
predictions = gr.Label(num_top_classes=3)
|
|
|
|
|
|
|
51 |
submit_btn.click(
|
52 |
fn=classify_arabic_dialect,
|
53 |
inputs=input_text,
|
54 |
+
outputs=predictions)
|
55 |
gr.HTML("""
|
56 |
<p style="text-align: center;font-size: large;">
|
57 |
Checkout the <a href="https://github.com/zaidmehdi/arabic-dialect-classifier">Github Repo</a>
|
|
|
60 |
|
61 |
|
62 |
if __name__ == "__main__":
|
63 |
+
demo.launch()
|