{ "cells": [ { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import pickle\n", "from time import time\n", "from tqdm import tqdm\n", "from typing import Union, List, Optional\n", "\n", "import math\n", "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from xgboost import XGBClassifier\n", "import numpy as np\n", "from nltk.tokenize import wordpunct_tokenize\n", "from sklearn.metrics import roc_curve, accuracy_score, f1_score\n", "from sklearn.utils import shuffle\n", "\n", "import sys\n", "ROOT = os.path.join(\"..\", \"..\", \"QuasarNix\")\n", "sys.path.append(ROOT)\n", "# NOTE: src is from https://github.com/dtrizna/QuasarNix -- temp, will be as pip later\n", "from src.data_utils import commands_to_loader, load_nl2bash\n", "from src.preprocessors import CommandTokenizer, OneHotCustomVectorizer\n", "\n", "TOKENIZER = wordpunct_tokenize\n", "SEED = 33\n", "VOCAB_SIZE = 4096\n", "EMBEDDED_DIM = 64\n", "DROPOUT = 0.5\n", "MAX_LEN = 128\n", "\n", "def lit_ckpt_to_torch(ckpt: str):\n", " \"\"\"\n", " Convert a lightning checkpoint to a torch state dict\n", " \"\"\"\n", " state_dict = torch.load(ckpt, map_location='cpu')['state_dict']\n", " \n", " for k, v in dict(state_dict).items():\n", " # lightning introduced model. prefix\n", " if k.startswith('model.'):\n", " state_dict[k[len('model.'):]] = v\n", " del state_dict[k]\n", "\n", " return state_dict\n", "\n", "def load_torch(path: str, model: nn.Module):\n", " if path.endswith('.ckpt'):\n", " state_dict = lit_ckpt_to_torch(path)\n", " model.load_state_dict(state_dict)\n", " torch.save(model.state_dict(), path.replace('.ckpt', '.torch'))\n", " elif path.endswith('.torch'):\n", " state_dict = torch.load(path, map_location='cpu')\n", " model.load_state_dict(state_dict)\n", " else:\n", " raise ValueError('Unknown model format')\n", " return model\n", "\n", "def load_xgb(path: str):\n", " if path.endswith('.pickle'):\n", " with open(path, 'rb') as f:\n", " model = pickle.load(f)\n", " model.save_model(path.replace('.pickle', '.xgboost'))\n", " elif path.endswith('.xgboost'):\n", " model = XGBClassifier(n_estimators=100, max_depth=10, random_state=SEED)\n", " model.load_model(path)\n", " else:\n", " raise ValueError('Unknown model format')\n", " return model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_test_baseline_cmds = load_nl2bash()\n", "\n", "# NOTE: malicious data is from https://huggingface.co/datasets/dtrizna/QuasarNix\n", "# orig test\n", "X_test_malicious_cmds_path = os.path.join(ROOT, \"data\", \"X_test_malicious_cmd_orig.json\")\n", "with open(X_test_malicious_cmds_path, 'r') as f:\n", " X_test_malicious_cmds = json.load(f)\n", "\n", "X_test_cmds_orig = X_test_malicious_cmds + X_test_baseline_cmds\n", "y_test_orig = np.array([1]*len(X_test_malicious_cmds) + [0]*len(X_test_baseline_cmds))\n", "X_test_cmds_orig, y_test_orig = shuffle(X_test_cmds_orig, y_test_orig, random_state=SEED)\n", "\n", "# adversarial test\n", "X_test_malicious_cmds_adv_path = os.path.join(ROOT, \"data\", \"X_test_malicious_cmd_adv.json\")\n", "with open(X_test_malicious_cmds_adv_path, 'r') as f:\n", " X_test_malicious_cmds_adv = json.load(f)\n", "\n", "X_test_cmds_adv = X_test_malicious_cmds_adv + X_test_baseline_cmds\n", "y_test_adv = np.array([1]*len(X_test_malicious_cmds_adv) + [0]*len(X_test_baseline_cmds))\n", "X_test_cmds_adv, y_test_adv = shuffle(X_test_cmds_adv, y_test_adv, random_state=SEED)" ] 
}, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# TABULAR MODELS\n", "\n", "oh_tokenizer_orig = OneHotCustomVectorizer(tokenizer=TOKENIZER, max_features=VOCAB_SIZE)\n", "oh_tokenizer_orig.load_vocab(\"quasarnix_tokenizer_data_train_onehot_orig.json\")\n", "\n", "oh_tokenizer_adv = OneHotCustomVectorizer(tokenizer=TOKENIZER, max_features=VOCAB_SIZE)\n", "oh_tokenizer_adv.load_vocab(\"quasarnix_tokenizer_data_train_onehot_adv.json\")\n", "\n", "oh_tokenizer_full = OneHotCustomVectorizer(tokenizer=TOKENIZER, max_features=VOCAB_SIZE)\n", "oh_tokenizer_full.load_vocab(\"quasarnix_tokenizer_data_train_onehot_full.json\")\n", "\n", "xgb_model_path_orig = './quasarnix_model_data_train_xgb_orig.xgboost'\n", "xgb_model_path_adv = './quasarnix_model_data_train_xgb_adv.xgboost'\n", "xgb_model_path_full = './quasarnix_model_data_full_xgb_adv.xgboost'\n", "\n", "mlp_model_path_orig = './quasarnix_model_data_train_mlp_orig.torch'\n", "mlp_model_path_adv = './quasarnix_model_data_train_mlp_adv.torch'\n", "mlp_model_path_full = './quasarnix_model_data_full_mlp_adv.torch'\n", "\n", "# SEQUENTIAL EMBEDDING MODELS\n", "\n", "vocab_path_orig = \"./quasarnix_tokenizer_data_train_vocab_orig.json\"\n", "tokenizer_orig = CommandTokenizer(tokenizer_fn=TOKENIZER, vocab_size=VOCAB_SIZE, max_len=MAX_LEN)\n", "tokenizer_orig.load_vocab(vocab_path_orig)\n", "\n", "vocab_path_adv = \"./quasarnix_tokenizer_data_train_vocab_adv.json\"\n", "tokenizer_adv = CommandTokenizer(tokenizer_fn=TOKENIZER, vocab_size=VOCAB_SIZE, max_len=MAX_LEN)\n", "tokenizer_adv.load_vocab(vocab_path_adv)\n", "\n", "vocab_path_full = \"./quasarnix_tokenizer_data_full_vocab_adv.json\"\n", "tokenizer_full = CommandTokenizer(tokenizer_fn=TOKENIZER, vocab_size=VOCAB_SIZE, max_len=MAX_LEN)\n", "tokenizer_full.load_vocab(vocab_path_full)\n", "\n", "cnn_model_path_orig = './quasarnix_model_data_train_cnn_orig.torch'\n", "cnn_model_path_adv = './quasarnix_model_data_train_cnn_adv.torch'\n", "cnn_model_path_full = './quasarnix_model_data_full_cnn_adv.torch'\n", "\n", "transformer_model_path_orig = './quasarnix_model_data_train_transformer_orig.torch'\n", "transformer_model_path_adv = './quasarnix_model_data_train_transformer_adv.torch'\n", "transformer_model_path_full = './quasarnix_model_data_full_transformer_adv.torch'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def get_preds(\n", " model: Union[XGBClassifier, nn.Module],\n", " X_cmds: List[str],\n", " y: np.ndarray,\n", " tokenizer: Union[OneHotCustomVectorizer, CommandTokenizer],\n", " threshold: Optional[float] = None\n", "):\n", " now = time()\n", " if isinstance(model, XGBClassifier):\n", " print(f\"[*] Working on {len(X_cmds)} samples\", end='\\r')\n", " X_encoded = tokenizer.transform(X_cmds)\n", " y_pred = model.predict_proba(X_encoded)[:, 1]\n", " elif isinstance(model, nn.Module):\n", " print(f\"[*] Building DataLoader for {len(X_cmds)} samples\", end='\\r')\n", " loader = commands_to_loader(X_cmds, tokenizer, y=y, workers=4, batch_size=256)\n", " model = model.to('cuda')\n", " model.eval()\n", " with torch.no_grad():\n", " y_pred = []\n", " for (x, _) in tqdm(loader, desc=\"[*] Predicting\", total=math.ceil(len(X_cmds)/256)):\n", " x = x.to('cuda')\n", " y_pred.append(model(x).cpu().numpy())\n", " y_pred = np.concatenate(y_pred).flatten()\n", " else:\n", " raise ValueError(\"Unknown model type\")\n", " \n", " if threshold is not None:\n", " y_pred = y_pred > threshold\n", " acc = accuracy_score(y, 
y_pred > 0.5)\n", " f1 = f1_score(y, y_pred > 0.5)\n", " print(f\"[!] Accuracy: {acc*100:.3f}%, F1: {f1*100:.3f}% | Took {time()-now:.2f}s\")\n", " \n", " return y_pred, y\n", "\n", "def score_both_sets(model, tokenizer):\n", " print(\"Original Test Set:\")\n", " y_pred_orig, y_true_orig = get_preds(model, X_test_cmds_orig, y_test_orig, tokenizer, threshold=0.5)\n", " print(\"\\nAdversarial Test Set:\")\n", " y_pred_adv, y_true_adv = get_preds(model, X_test_cmds_adv, y_test_adv, tokenizer, threshold=0.5)\n", " return y_pred_orig, y_true_orig, y_pred_adv, y_true_adv\n", "\n", "def plot_roc(y_true, y_pred, ax):\n", " fpr, tpr, _ = roc_curve(y_true, y_pred)\n", " ax.plot(fpr, tpr)\n", " return ax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradient Boosted Decision Trees (GBDT) with XGBoost" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original Test Set:\n", "[!] Accuracy: 99.968%, F1: 99.968% | Took 25.27s\n", "\n", "Adversarial Test Set:\n", "[!] Accuracy: 83.418%, F1: 80.123% | Took 31.22s\n" ] } ], "source": [ "xgb_orig = load_xgb(xgb_model_path_orig)\n", "_ = score_both_sets(xgb_orig, oh_tokenizer_orig)\n", "\n", "# see params\n", "# print(xgb_full_adv.get_booster().get_dump()[0])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original Test Set:\n", "[*] Working on 470129 samples\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 99.954%, F1: 99.954% | Took 24.01s\n", "\n", "Adversarial Test Set:\n", "[!] Accuracy: 99.975%, F1: 99.975% | Took 32.38s\n" ] } ], "source": [ "xgb_adv = load_xgb(xgb_model_path_adv)\n", "_ = score_both_sets(xgb_adv, oh_tokenizer_adv)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original Test Set:\n", "[*] Working on 470129 samples\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 100.000%, F1: 100.000% | Took 27.10s\n", "\n", "Adversarial Test Set:\n", "[!] Accuracy: 100.000%, F1: 100.000% | Took 31.26s\n" ] } ], "source": [ "xgb_full_adv = load_xgb(xgb_model_path_full)\n", "_ = score_both_sets(xgb_full_adv, oh_tokenizer_full)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tabular Fully Connected Neural Network (aka MLP) with PyTorch" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:10<00:00, 181.18it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 99.930%, F1: 99.930% | Took 34.68s\n", "\n", "Adversarial Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:10<00:00, 170.18it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] 
Accuracy: 89.179%, F1: 87.868% | Took 42.86s\n" ] } ], "source": [ "class SimpleMLP(nn.Module):\n", " def __init__(self, input_dim, output_dim, hidden_dim=[32], dropout=None):\n", " if isinstance(hidden_dim, int):\n", " hidden_dim = [hidden_dim]\n", " \n", " super().__init__()\n", " layers = []\n", " prev_dim = input_dim\n", " \n", " # Dynamically create hidden layers based on hidden_dim\n", " for h_dim in hidden_dim:\n", " layers.append(nn.Linear(prev_dim, h_dim))\n", " layers.append(nn.ReLU())\n", " if dropout:\n", " layers.append(nn.Dropout(dropout))\n", " prev_dim = h_dim\n", " \n", " layers.append(nn.Linear(prev_dim, output_dim))\n", " self.model = nn.Sequential(*layers)\n", " \n", " def forward(self, x):\n", " return self.model(x)\n", "\n", "\n", "mlp_orig = SimpleMLP(\n", " input_dim=VOCAB_SIZE,\n", " output_dim=1,\n", " hidden_dim=[64, 32],\n", " dropout=DROPOUT\n", ") # 264 K params\n", "\n", "\n", "mlp_orig = load_torch(mlp_model_path_orig, mlp_orig)\n", "_ = score_both_sets(mlp_orig, oh_tokenizer_orig)\n", "\n", "# see params\n", "# mlp_orig.state_dict()\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:11<00:00, 158.51it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 99.943%, F1: 99.943% | Took 34.79s\n", "\n", "Adversarial Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:11<00:00, 157.97it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 99.968%, F1: 99.968% | Took 44.49s\n" ] } ], "source": [ "mlp_adv = SimpleMLP(\n", " input_dim=VOCAB_SIZE,\n", " output_dim=1,\n", " hidden_dim=[64, 32],\n", " dropout=DROPOUT\n", ") # 264 K params\n", "\n", "mlp_adv = load_torch(mlp_model_path_adv, mlp_adv)\n", "_ = score_both_sets(mlp_adv, oh_tokenizer_adv)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:11<00:00, 155.91it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 99.999%, F1: 99.999% | Took 35.86s\n", "\n", "Adversarial Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:10<00:00, 169.64it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] 
Accuracy: 99.999%, F1: 99.999% | Took 43.86s\n" ] } ], "source": [ "mlp_full_adv = SimpleMLP(\n", " input_dim=VOCAB_SIZE,\n", " output_dim=1,\n", " hidden_dim=[64, 32],\n", " dropout=DROPOUT\n", ") # 264 K params\n", "\n", "mlp_full_adv = load_torch(mlp_model_path_full, mlp_full_adv)\n", "_ = score_both_sets(mlp_full_adv, oh_tokenizer_full)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1D Convolutional Neural Network with PyTorch" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:02<00:00, 644.75it/s] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 97.619%, F1: 97.561% | Took 12.85s\n", "\n", "Adversarial Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:02<00:00, 716.00it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 77.520%, F1: 71.001% | Took 15.09s\n" ] } ], "source": [ "class CNN1DGroupedModel(nn.Module):\n", " def __init__(self, vocab_size, embed_dim, num_channels, kernel_sizes, mlp_hidden_dims, output_dim, dropout=None):\n", " super().__init__()\n", " \n", " self.embedding = nn.Embedding(vocab_size, embed_dim)\n", " self.grouped_convs = nn.ModuleList([nn.Conv1d(embed_dim, num_channels, kernel) for kernel in kernel_sizes])\n", " \n", " mlp_input_dim = num_channels * len(kernel_sizes)\n", " self.mlp = SimpleMLP(input_dim=mlp_input_dim, output_dim=output_dim, hidden_dim=mlp_hidden_dims, dropout=dropout)\n", "\n", " @staticmethod\n", " def conv_and_pool(x, conv):\n", " conv_out = conv(x)\n", " pooled = F.max_pool1d(conv_out, conv_out.size(2)).squeeze(2)\n", " return pooled\n", " \n", " def forward(self, x):\n", " x = self.embedding(x).transpose(1, 2)\n", " conv_outputs = [self.conv_and_pool(x, conv) for conv in self.grouped_convs]\n", "\n", " x = torch.cat(conv_outputs, dim=1)\n", " return self.mlp(x)\n", "\n", "\n", "cnn_orig = CNN1DGroupedModel(\n", " vocab_size=VOCAB_SIZE,\n", " embed_dim=EMBEDDED_DIM,\n", " num_channels=32,\n", " kernel_sizes=[2, 3, 4, 5],\n", " mlp_hidden_dims=[64, 32],\n", " output_dim=1,\n", " dropout=DROPOUT\n", ") # 301 K params\n", "\n", "cnn_orig = load_torch(cnn_model_path_orig, cnn_orig)\n", "_ = score_both_sets(cnn_orig, tokenizer_orig)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:03<00:00, 605.99it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 99.992%, F1: 99.992% | Took 13.84s\n", "\n", "Adversarial Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:02<00:00, 653.34it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] 
Accuracy: 99.995%, F1: 99.995% | Took 14.61s\n" ] } ], "source": [ "cnn_adv = CNN1DGroupedModel(\n", " vocab_size=VOCAB_SIZE,\n", " embed_dim=EMBEDDED_DIM,\n", " num_channels=32,\n", " kernel_sizes=[2, 3, 4, 5],\n", " mlp_hidden_dims=[64, 32],\n", " output_dim=1,\n", " dropout=DROPOUT\n", ") # 301 K params\n", "\n", "cnn_adv = load_torch(cnn_model_path_adv, cnn_adv)\n", "_ = score_both_sets(cnn_adv, tokenizer_adv)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:02<00:00, 736.91it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 99.999%, F1: 99.999% | Took 12.76s\n", "\n", "Adversarial Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:02<00:00, 625.40it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 99.999%, F1: 99.999% | Took 14.06s\n" ] } ], "source": [ "cnn_full_adv = CNN1DGroupedModel(\n", " vocab_size=VOCAB_SIZE,\n", " embed_dim=EMBEDDED_DIM,\n", " num_channels=32,\n", " kernel_sizes=[2, 3, 4, 5],\n", " mlp_hidden_dims=[64, 32],\n", " output_dim=1,\n", " dropout=DROPOUT\n", ") # 301 K params\n", "\n", "cnn_full_adv = load_torch(cnn_model_path_full, cnn_full_adv)\n", "_ = score_both_sets(cnn_full_adv, tokenizer_full)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transformer Encoder for Classification" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "class PositionalEncoding(nn.Module):\n", " def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):\n", " super().__init__()\n", " self.dropout = nn.Dropout(p=dropout)\n", "\n", " # Initialize pe with shape [1, max_len, d_model] for broadcasting\n", " pe = torch.zeros(1, max_len, d_model)\n", " position = torch.arange(max_len).unsqueeze(1)\n", " div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n", " pe[0, :, 0::2] = torch.sin(position * div_term)\n", " pe[0, :, 1::2] = torch.cos(position * div_term)\n", " self.register_buffer('pe', pe)\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " Args:\n", " x: Tensor, shape [batch_size, seq_len, embedding_dim]\n", " \"\"\"\n", " # Use broadcasting to add positional encoding\n", " x = x + self.pe[:, :x.size(1), :]\n", " return self.dropout(x)\n", "\n", "class BaseTransformerEncoder(nn.Module):\n", " def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward, max_len, dropout=None):\n", " super(BaseTransformerEncoder, self).__init__()\n", " \n", " assert d_model % nhead == 0, \"nheads must divide evenly into d_model\"\n", " self.embedding = nn.Embedding(vocab_size, d_model)\n", " self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=max_len)\n", " encoder_layer = nn.TransformerEncoderLayer(\n", " d_model,\n", " nhead,\n", " dim_feedforward,\n", " dropout,\n", " norm_first=True,\n", " batch_first=True\n", " )\n", " self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)\n", "\n", " def encode(self, src, src_mask=None, src_key_padding_mask=None):\n", " src = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)\n", " src = self.pos_encoder(src)\n", " return self.transformer_encoder(src, 
mask=src_mask, src_key_padding_mask=src_key_padding_mask)\n", "\n", "class CLSTransformerEncoder(BaseTransformerEncoder):\n", " def __init__(self, mlp_hidden_dims, output_dim, *args, **kwargs):\n", " kwargs[\"max_len\"] += 1 # to account for CLS token\n", " super(CLSTransformerEncoder, self).__init__(*args, **kwargs)\n", " self.cls_token = nn.Parameter(torch.randn(1, 1, self.embedding.embedding_dim))\n", " self.decoder = SimpleMLP(input_dim=self.embedding.embedding_dim, output_dim=output_dim, hidden_dim=mlp_hidden_dims, dropout=kwargs.get(\"dropout\"))\n", "\n", " def forward(self, src, src_mask=None, src_key_padding_mask=None):\n", " # Embed the src token indices\n", " src = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)\n", " \n", " # Repeat the cls_token for every item in the batch and concatenate it to src\n", " cls_tokens = self.cls_token.repeat(src.size(0), 1, 1)\n", " src = torch.cat([cls_tokens, src], dim=1)\n", " \n", " # Add positional encoding\n", " src = self.pos_encoder(src)\n", " \n", " # Pass through transformer encoder\n", " output = self.transformer_encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)\n", " \n", " # Extract the encoding corresponding to the cls_token\n", " output = output[:, 0, :] # [B, E]\n", " \n", " return self.decoder(output)\n", "\n", "\n", "class MeanTransformerEncoder(BaseTransformerEncoder):\n", " def __init__(self, mlp_hidden_dims, output_dim, *args, **kwargs):\n", " super(MeanTransformerEncoder, self).__init__(*args, **kwargs)\n", " self.decoder = SimpleMLP(input_dim=self.embedding.embedding_dim, output_dim=output_dim, hidden_dim=mlp_hidden_dims, dropout=kwargs.get(\"dropout\"))\n", "\n", " def forward(self, src, src_mask=None, src_key_padding_mask=None):\n", " output = self.encode(src, src_mask, src_key_padding_mask)\n", " output = output.mean(dim=1)\n", " return self.decoder(output)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/dtrizna/.local/lib/python3.10/site-packages/torch/nn/modules/transformer.py:286: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True\n", " warnings.warn(f\"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}\")\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:04<00:00, 387.08it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 99.696%, F1: 99.696% | Took 15.01s\n", "\n", "Adversarial Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:07<00:00, 234.15it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] 
Accuracy: 86.388%, F1: 84.245% | Took 20.70s\n" ] } ], "source": [ "transformer_orig = CLSTransformerEncoder(\n", " vocab_size=VOCAB_SIZE,\n", " d_model=EMBEDDED_DIM,\n", " nhead=4,\n", " num_layers=2,\n", " dim_feedforward=128,\n", " max_len=128,\n", " dropout=DROPOUT,\n", " mlp_hidden_dims=[64, 32],\n", " output_dim=1\n", ") # 335 K params\n", "\n", "transformer_orig = load_torch(transformer_model_path_orig, transformer_orig)\n", "_ = score_both_sets(transformer_orig, tokenizer_orig)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/dtrizna/.local/lib/python3.10/site-packages/torch/nn/modules/transformer.py:286: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True\n", " warnings.warn(f\"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}\")\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:05<00:00, 307.70it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 99.749%, F1: 99.750% | Took 17.17s\n", "\n", "Adversarial Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:05<00:00, 330.66it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 99.749%, F1: 99.750% | Took 18.48s\n" ] } ], "source": [ "transformer_adv = CLSTransformerEncoder(\n", " vocab_size=VOCAB_SIZE,\n", " d_model=EMBEDDED_DIM,\n", " nhead=4,\n", " num_layers=2,\n", " dim_feedforward=128,\n", " max_len=128,\n", " dropout=DROPOUT,\n", " mlp_hidden_dims=[64, 32],\n", " output_dim=1\n", ") # 335 K params\n", "\n", "transformer_adv = load_torch(transformer_model_path_adv, transformer_adv)\n", "_ = score_both_sets(transformer_adv, tokenizer_adv)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/dtrizna/.local/lib/python3.10/site-packages/torch/nn/modules/transformer.py:286: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True\n", " warnings.warn(f\"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}\")\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:05<00:00, 334.43it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] Accuracy: 99.997%, F1: 99.997% | Took 15.61s\n", "\n", "Adversarial Test Set:\n", "[*] Building DataLoader for 470129 samples\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[*] Predicting: 100%|██████████| 1837/1837 [00:05<00:00, 334.75it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[!] 
Accuracy: 99.997%, F1: 99.997% | Took 18.08s\n" ] } ], "source": [ "transformer_full_adv = CLSTransformerEncoder(\n", " vocab_size=VOCAB_SIZE,\n", " d_model=EMBEDDED_DIM,\n", " nhead=4,\n", " num_layers=2,\n", " dim_feedforward=128,\n", " max_len=128,\n", " dropout=DROPOUT,\n", " mlp_hidden_dims=[64, 32],\n", " output_dim=1\n", ") # 335 K params\n", "\n", "transformer_full_adv = load_torch(transformer_model_path_full, transformer_full_adv)\n", "_ = score_both_sets(transformer_full_adv, tokenizer_full)" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }