File size: 2,116 Bytes
ec4a7b0
1784a22
c1f16ee
7ecfa8c
9a23b5c
 
45c84cc
99757c1
 
747f8ea
ec4a7b0
 
 
 
 
 
 
 
99757c1
747f8ea
 
 
 
 
 
 
 
 
 
 
9a23b5c
 
 
99757c1
747f8ea
7ecfa8c
 
9365c1c
2d9e152
 
9365c1c
2d9e152
c1f16ee
747f8ea
 
 
47e7f1f
747f8ea
 
2d9e152
747f8ea
 
 
2d9e152
47e7f1f
 
 
 
 
 
 
 
 
 
490e341
 
 
 
 
c1f16ee
 
 
2d9e152
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import pickle

import gradio as gr
from transformers import AutoModel, AutoTokenizer

from .utils import extract_hidden_state


# Load the trained dialect classifier from ../models/logistic_regression.pkl.
# It is expected to expose ``predict_proba`` and ``classes_`` (both are used by
# ``classify_arabic_dialect``).
models_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
model_file = os.path.join(models_dir, 'logistic_regression.pkl')

# Always bind ``model`` so that code using it can detect a missing classifier
# (``model is None``) instead of failing later with an opaque NameError.
model = None
if os.path.exists(model_file):
    # NOTE(security): unpickling can execute arbitrary code. This file must be
    # a trusted, locally produced artifact; never point ``model_file`` at
    # user-supplied or downloaded data.
    with open(model_file, "rb") as f:
        model = pickle.load(f)
else:
    print(f"Error: {model_file} not found.")

# Load the HTML header shown at the top of the Gradio page (templates/index.html
# next to this module).
html_dir = os.path.join(os.path.dirname(__file__), "templates")
index_html_path = os.path.join(html_dir, "index.html")

# Default to an empty header so the page can still be built by
# ``gr.HTML(index_html)`` if the template is missing, rather than crashing on
# an unbound name.
index_html = ""
if os.path.exists(index_html_path):
    # Decode explicitly as UTF-8: the platform default encoding (e.g. cp1252 on
    # Windows) may fail on or mangle non-ASCII template content.
    with open(index_html_path, "r", encoding="utf-8") as html_file:
        index_html = html_file.read()
else:
    print(f"Error: {index_html_path} not found.")

# Pre-trained AraBART checkpoint whose tokenizer and model are passed to
# ``extract_hidden_state`` to embed input text for the classifier. Both are
# loaded from the same checkpoint id (tokenizer first, then the model).
model_name = "moussaKam/AraBART"
tokenizer, language_model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModel.from_pretrained(model_name),
)


def classify_arabic_dialect(text):
    """Return the classifier's probability for every dialect label.

    The text is embedded with ``extract_hidden_state`` (using the module-level
    ``tokenizer`` and ``language_model``) and scored by the module-level
    classifier ``model``.

    Args:
        text: Input string from the Gradio textbox.

    Returns:
        A dict mapping each label in ``model.classes_`` to its probability as a
        plain Python ``float``.

    Raises:
        RuntimeError: If the classifier was not loaded at import time.
    """
    # ``model`` is only loaded when the pickle file existed at import time, so
    # it may be unbound or None; fail with an explicit message in that case.
    classifier = globals().get("model")
    if classifier is None:
        raise RuntimeError(
            "Dialect classifier is not loaded; check the models directory."
        )

    text_embeddings = extract_hidden_state(text, tokenizer, language_model)
    # predict_proba returns one row per sample; a single text was passed.
    probabilities = classifier.predict_proba(text_embeddings)[0]
    # Probabilities are ordered like ``classes_``, so pair them positionally.
    # Cast NumPy scalars to float so the result serializes cleanly.
    return {
        label: float(prob)
        for label, prob in zip(classifier.classes_, probabilities)
    }


# Build the Gradio page. Components are laid out in the order they are created
# inside the ``with`` block: header, input box, submit button, top-3 result
# label, example sentences, footer link.
with gr.Blocks() as demo:
    gr.HTML(index_html)

    input_text = gr.Textbox(label="Your Arabic Text")
    submit_btn = gr.Button("Submit")
    predictions = gr.Label(num_top_classes=3)

    # Score the textbox contents and show them in the label component.
    submit_btn.click(
        fn=classify_arabic_dialect,
        inputs=input_text,
        outputs=predictions,
    )

    gr.Markdown("## Text Examples")
    # No fn/outputs are given, so an example only populates ``input_text``.
    example_texts = [
        "واش نتا خدام ولا لا",
        "بصح راك فاهم لازم الزيت",
        "حضرتك بروح زي كدا؟ على طول النهار ده",
    ]
    examples = gr.Examples(examples=example_texts, inputs=input_text)

    footer_html = """
            <p style="text-align: center;font-size: large;">
            Checkout the <a href="https://github.com/zaidmehdi/arabic-dialect-classifier">Github Repo</a>
            </p>
            """
    gr.HTML(footer_html)


# Script entry point: start the Gradio server for ``demo``.
# NOTE: this module uses a relative import (``from .utils import ...``), so it
# must be run as part of its package (``python -m <package>.<module>``);
# executing the file directly raises ImportError before reaching this line.
if __name__ == "__main__":
    demo.launch()