{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Arabic Dialect Classifier\n", "This notebook contains the training of the classifier model. The goal is to classify the dialects at the country level." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "\n", "from datasets import DatasetDict, Dataset\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.model_selection import RandomizedSearchCV\n", "from sklearn.preprocessing import LabelEncoder\n", "import torch\n", "from transformers import AutoModel, AutoTokenizer\n", "import xgboost as xgb" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Exploring the Dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "df_train = pd.read_csv(\"../data/DA_train_labeled.tsv\", sep=\"\\t\")\n", "df_test = pd.read_csv(\"../data/DA_dev_labeled.tsv\", sep=\"\\t\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | #1_tweetid | \n", "#2_tweet | \n", "#3_country_label | \n", "#4_province_label | \n", "
---|---|---|---|---|
0 | \n", "TRAIN_0 | \n", "حاجة حلوة اكيد | \n", "Egypt | \n", "eg_Faiyum | \n", "
1 | \n", "TRAIN_1 | \n", "عم بشتغلوا للشعب الاميركي اما نحن يكذبوا ويغشو... | \n", "Iraq | \n", "iq_Dihok | \n", "
2 | \n", "TRAIN_2 | \n", "ابشر طال عمرك | \n", "Saudi_Arabia | \n", "sa_Ha'il | \n", "
3 | \n", "TRAIN_3 | \n", "منطق 2017: أنا والغريب علي إبن عمي وأنا والغري... | \n", "Mauritania | \n", "mr_Nouakchott | \n", "
4 | \n", "TRAIN_4 | \n", "شهرين وتروح والباقي غير صيف ملينا | \n", "Algeria | \n", "dz_El-Oued | \n", "
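A quick label count makes the class balance visible before any modeling. The cell below is a minimal sketch using the already-imported pandas and matplotlib, not a cell recovered from the notebook:

```python
# Sketch: distribution of tweets per country in the training split.
label_counts = df_train["#3_country_label"].value_counts()
print(label_counts)

label_counts.plot(kind="bar", figsize=(10, 4),
                  title="Tweets per country (training split)")
plt.tight_layout()
plt.show()
```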
\n", " | #1_tweetid | \n", "#2_tweet | \n", "#3_country_label | \n", "#4_province_label | \n", "
---|---|---|---|---|
0 | \n", "DEV_0 | \n", "قولنا اون لاين لا يا علي اون لاين لا | \n", "Egypt | \n", "eg_Alexandria | \n", "
1 | \n", "DEV_1 | \n", "ههههه بايخه ههههه URL … | \n", "Oman | \n", "om_Muscat | \n", "
2 | \n", "DEV_2 | \n", "ربنا يخليك يا دوك ولك المثل :D | \n", "Lebanon | \n", "lb_South-Lebanon | \n", "
3 | \n", "DEV_3 | \n", "#اوامر_ملكيه ياشباب اي واحد فيكم عنده شي يذكره... | \n", "Syria | \n", "sy_Damascus-City | \n", "
4 | \n", "DEV_4 | \n", "شد عالخط حتى هيا اكويسه | \n", "Libya | \n", "ly_Misrata | \n", "
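## 2. Training the Classifiers

The imports pull in `AutoModel` and `AutoTokenizer`, so the tweets are presumably turned into transformer embeddings before the scikit-learn models see them. The sketch below shows one common recipe, mean-pooled sentence embeddings; the checkpoint name `UBC-NLP/MARBERT` and all variable names are assumptions, not taken from the notebook:

```python
# Sketch, assuming a pretrained Arabic transformer as the feature extractor.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("UBC-NLP/MARBERT")  # assumed checkpoint
model = AutoModel.from_pretrained("UBC-NLP/MARBERT").to(device).eval()

def embed(texts, batch_size=32):
    """Mean-pool the last hidden state over non-padding tokens."""
    chunks = []
    for i in range(0, len(texts), batch_size):
        batch = tokenizer(texts[i:i + batch_size], padding=True, truncation=True,
                          max_length=128, return_tensors="pt").to(device)
        with torch.no_grad():
            hidden = model(**batch).last_hidden_state       # (B, T, H)
        mask = batch["attention_mask"].unsqueeze(-1)        # (B, T, 1)
        chunks.append(((hidden * mask).sum(1) / mask.sum(1)).cpu().numpy())
    return np.concatenate(chunks)

X_train = embed(df_train["#2_tweet"].tolist())
X_test = embed(df_test["#2_tweet"].tolist())

# Integer-encode the country labels for the classifiers below.
label_encoder = LabelEncoder()
y_train = label_encoder.fit_transform(df_train["#3_country_label"])
y_test = label_encoder.transform(df_test["#3_country_label"])
```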
The first model is a multinomial logistic regression baseline; `class_weight='balanced'` reweights classes by inverse frequency to counter the uneven country distribution:

```
LogisticRegression(class_weight='balanced', max_iter=1000,
                   multi_class='multinomial', random_state=2024)
```
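Fitting and scoring the baseline is sketched below; macro F1 matches the `f1_macro` scoring used for the forest search further down, and `X_train`/`y_train` refer to the assumed names from the embedding sketch above:

```python
from sklearn.metrics import classification_report, f1_score

# Sketch: fit the baseline and score it on the dev split (names assumed above).
log_reg = LogisticRegression(class_weight="balanced", max_iter=1000,
                             multi_class="multinomial", random_state=2024)
log_reg.fit(X_train, y_train)

pred = log_reg.predict(X_test)
print("macro F1:", f1_score(y_test, pred, average="macro"))
print(classification_report(y_test, pred, target_names=label_encoder.classes_))
```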
A random forest is then tuned with randomized search over tree depth and ensemble size, using 5-fold cross-validation and macro F1 as the selection metric:

```
RandomizedSearchCV(cv=5,
                   estimator=RandomForestClassifier(class_weight='balanced',
                                                    random_state=2024),
                   n_iter=20,
                   param_distributions={'max_depth': [3, 4, 5, 6, 7, 8],
                                        'n_estimators': [100, 150, 200, 250,
                                                         300, 400, 500]},
                   scoring='f1_macro')
```
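Once fitted, the search object exposes the winning configuration and a model refit on the full training split (the scikit-learn default `refit=True`); a usage sketch with assumed names:

```python
# Sketch: run the search and pull out the tuned forest (names assumed).
search = RandomizedSearchCV(
    estimator=RandomForestClassifier(class_weight="balanced", random_state=2024),
    param_distributions={"max_depth": [3, 4, 5, 6, 7, 8],
                         "n_estimators": [100, 150, 200, 250, 300, 400, 500]},
    n_iter=20, cv=5, scoring="f1_macro",
)
search.fit(X_train, y_train)

print(search.best_params_)        # winning depth and number of trees
best_rf = search.best_estimator_  # already refit on the full training split
```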
The base random forest handed to the search, shown with only its non-default parameters:

```
RandomForestClassifier(class_weight='balanced', random_state=2024)
```
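`xgboost` and `pickle` also appear in the imports; below is a minimal sketch of a gradient-boosted counterpart and of persisting the chosen model, with the configuration and path assumed rather than taken from the notebook:

```python
# Sketch: an XGBoost counterpart to the models above (hyperparameters assumed).
xgb_clf = xgb.XGBClassifier(n_estimators=300, max_depth=6,
                            eval_metric="mlogloss", random_state=2024)
xgb_clf.fit(X_train, y_train)

# Persist whichever model wins together with the label encoder it depends on.
with open("../models/dialect_classifier.pkl", "wb") as f:  # assumed path
    pickle.dump({"model": xgb_clf, "label_encoder": label_encoder}, f)
```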