{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Arabic Dialect Classifier\n", "This notebook contains the training of the classifier model. The goal is to classify the dialects at the country level." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mehdi/miniconda3/envs/adc/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import pickle\n", "\n", "from datasets import DatasetDict, Dataset\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.model_selection import RandomizedSearchCV\n", "from sklearn.preprocessing import LabelEncoder\n", "import torch\n", "from transformers import AutoModel, AutoTokenizer\n", "import xgboost as xgb" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Exploring the Dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "df_train = pd.read_csv(\"../data/DA_train_labeled.tsv\", sep=\"\\t\")\n", "df_test = pd.read_csv(\"../data/DA_dev_labeled.tsv\", sep=\"\\t\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | #1_tweetid | \n", "#2_tweet | \n", "#3_country_label | \n", "#4_province_label | \n", "
---|---|---|---|---|
0 | \n", "TRAIN_0 | \n", "حاجة حلوة اكيد | \n", "Egypt | \n", "eg_Faiyum | \n", "
1 | \n", "TRAIN_1 | \n", "عم بشتغلوا للشعب الاميركي اما نحن يكذبوا ويغشو... | \n", "Iraq | \n", "iq_Dihok | \n", "
2 | \n", "TRAIN_2 | \n", "ابشر طال عمرك | \n", "Saudi_Arabia | \n", "sa_Ha'il | \n", "
3 | \n", "TRAIN_3 | \n", "منطق 2017: أنا والغريب علي إبن عمي وأنا والغري... | \n", "Mauritania | \n", "mr_Nouakchott | \n", "
4 | \n", "TRAIN_4 | \n", "شهرين وتروح والباقي غير صيف ملينا | \n", "Algeria | \n", "dz_El-Oued | \n", "
\n", " | #1_tweetid | \n", "#2_tweet | \n", "#3_country_label | \n", "#4_province_label | \n", "
---|---|---|---|---|
0 | \n", "DEV_0 | \n", "قولنا اون لاين لا يا علي اون لاين لا | \n", "Egypt | \n", "eg_Alexandria | \n", "
1 | \n", "DEV_1 | \n", "ههههه بايخه ههههه URL … | \n", "Oman | \n", "om_Muscat | \n", "
2 | \n", "DEV_2 | \n", "ربنا يخليك يا دوك ولك المثل :D | \n", "Lebanon | \n", "lb_South-Lebanon | \n", "
3 | \n", "DEV_3 | \n", "#اوامر_ملكيه ياشباب اي واحد فيكم عنده شي يذكره... | \n", "Syria | \n", "sy_Damascus-City | \n", "
4 | \n", "DEV_4 | \n", "شد عالخط حتى هيا اكويسه | \n", "Libya | \n", "ly_Misrata | \n", "
LogisticRegression(class_weight='balanced', max_iter=1000,
                   multi_class='multinomial', random_state=2024)
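The printout above is a multinomial logistic-regression baseline with balanced class weights. A minimal sketch of the corresponding fit and dev-set evaluation, assuming the placeholder features and labels from the preparation sketch and macro F1 as the metric (the same scoring the randomized search below uses):

```python
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

# Hyperparameters copied from the printed estimator; X_*/y_* are the
# placeholder features and labels from the preparation sketch above.
log_reg = LogisticRegression(class_weight="balanced", max_iter=1000,
                             multi_class="multinomial", random_state=2024)
log_reg.fit(X_train, y_train)
print("macro F1:", f1_score(y_test, log_reg.predict(X_test), average="macro"))
```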
RandomizedSearchCV(cv=5,
                   estimator=RandomForestClassifier(class_weight='balanced',
                                                    random_state=2024),
                   n_iter=20,
                   param_distributions={'max_depth': [3, 4, 5, 6, 7, 8],
                                        'n_estimators': [100, 150, 200, 250,
                                                         300, 400, 500]},
                   scoring='f1_macro')
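The random-forest hyperparameters are tuned with a 20-iteration randomized search over tree depth and ensemble size, scored by macro F1 with 5-fold cross-validation. A sketch of the equivalent call, with the grid copied from the printout above (the training data are again the placeholder variables):

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

# Search configuration copied from the printed RandomizedSearchCV.
rf_search = RandomizedSearchCV(
    estimator=RandomForestClassifier(class_weight="balanced", random_state=2024),
    param_distributions={
        "max_depth": [3, 4, 5, 6, 7, 8],
        "n_estimators": [100, 150, 200, 250, 300, 400, 500],
    },
    n_iter=20,
    scoring="f1_macro",
    cv=5,
)
rf_search.fit(X_train, y_train)
print(rf_search.best_params_)
```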
LogisticRegression(class_weight='balanced', max_iter=1000,
                   multi_class='multinomial', random_state=2024)
RandomForestClassifier(class_weight='balanced', max_depth=8, n_estimators=400,
                       random_state=2024)
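The forest printed above (max_depth=8, n_estimators=400) is consistent with the best configuration reported by such a search. Under the same placeholder names, it can be retrieved and scored on the dev set like this:

```python
from sklearn.metrics import f1_score

# Assumes rf_search from the sketch above; best_estimator_ is already refit
# on the full training set by RandomizedSearchCV's default refit=True.
best_rf = rf_search.best_estimator_
print("macro F1:", f1_score(y_test, best_rf.predict(X_test), average="macro"))
```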
XGBClassifier(base_score=None, booster=None, callbacks=None,
              colsample_bylevel=None, colsample_bynode=None,
              colsample_bytree=None, device='cuda', early_stopping_rounds=None,
              enable_categorical=False, eval_metric=None, feature_types=None,
              gamma=None, grow_policy=None, importance_type=None,
              interaction_constraints=None, learning_rate=0.1, max_bin=None,
              max_cat_threshold=None, max_cat_to_onehot=None,
              max_delta_step=None, max_depth=7, max_leaves=None,
              min_child_weight=None, missing=nan, monotone_constraints=None,
              multi_strategy=None, n_estimators=450, n_jobs=None,
              num_parallel_tree=None, objective='multi:softprob', ...)
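Finally, the gradient-boosted model. Its non-default settings in the printout above are GPU training (device='cuda'), learning_rate=0.1, max_depth=7, n_estimators=450, and the multiclass softprob objective. A sketch of an equivalent fit; the pickle step and the file path are assumptions based only on the `import pickle` at the top of the notebook:

```python
import pickle

import xgboost as xgb

# Non-default hyperparameters copied from the printed XGBClassifier.
# device="cuda" requires a GPU-enabled xgboost build; drop it to train on CPU.
xgb_clf = xgb.XGBClassifier(
    device="cuda",
    learning_rate=0.1,
    max_depth=7,
    n_estimators=450,
    objective="multi:softprob",
)
xgb_clf.fit(X_train, y_train)

# Persist the fitted model (path is illustrative only).
with open("../models/xgb_country_classifier.pkl", "wb") as f:
    pickle.dump(xgb_clf, f)
```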