|
---
datasets:
- esnli
license: apache-2.0
language:
- en
metrics:
- accuracy
- f1
- precision
- recall
model-index:
- name: backpack-gpt2-nli
  results:
  - task:
      name: Natural Language Inference
      type: text-classification
    dataset:
      name: e-SNLI
      type: esnli
      split: validation
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.9006299532615322
    - name: F1
      type: f1
      value: 0.9004261302857443
    - name: Precision
      type: precision
      value: 0.9004584180714215
    - name: Recall
      type: recall
      value: 0.9004554220756779
  - task:
      name: Natural Language Inference
      type: text-classification
    dataset:
      name: e-SNLI
      type: esnli
      split: test
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.8957654723127035
    - name: F1
      type: f1
      value: 0.8954702227331482
    - name: Precision
      type: precision
      value: 0.8954036872157838
    - name: Recall
      type: recall
      value: 0.8955997285576146
pipeline_tag: text-classification
tags:
- Natural Language Inference
- Sequence Classification
- GPT2
- Backpack
- ESNLI
---
|
# Model Card for Backpack-GPT2-NLI |
|
This is a fine-tuned version of [backpack-gpt2](https://huggingface.co/stanfordnlp/backpack-gpt2) with an NLI classification head, trained on the [esnli](https://huggingface.co/datasets/esnli) dataset.
|
Results:

- On the validation set:
  - Cross-entropy loss: 0.3168
  - Accuracy: 0.9006
  - F1: 0.9004
- On the test set:
  - Cross-entropy loss: 0.3277
  - Accuracy: 0.8958
  - F1: 0.8955
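
For context, each e-SNLI record pairs a premise with a hypothesis and a three-way label. The quick inspection below is a minimal sketch, assuming the field names and label order used by the `esnli` dataset on the Hugging Face Hub:

```python
from datasets import load_dataset

# Load the e-SNLI dataset used for fine-tuning; 'validation' and 'test' are the
# splits reported in the results above.
esnli = load_dataset('esnli')

# Each record has a premise, a hypothesis, an integer label
# (0=entailment, 1=neutral, 2=contradiction per the SNLI convention),
# and human-written explanations.
example = esnli['validation'][0]
print(example['premise'], example['hypothesis'], example['label'])
```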
|
### Model Description |
|
- **Developed by:** [Erfan Moosavi Monazzah](https://huggingface.co/ErfanMoosaviMonazzah) |
|
- **Model type:** Sequence Classifier |
|
- **Language(s) (NLP):** English |
|
- **License:** apache-2.0 |
|
- **Fine-tuned from model:** [Backpack-GPT2](https://huggingface.co/stanfordnlp/backpack-gpt2)
|
|
|
|
|
## How to Get Started with the Model |
|
```python |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# The model uses the GPT-2 tokenizer; GPT-2 has no pad token, so reuse EOS.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    # Join each premise/hypothesis pair into a single sequence, separated by " ^ ".
    concatenated_sentences = [
        f'{premise.strip(".")}. ^ {hypothesis.strip(".")}.'
        for premise, hypothesis in zip(examples['premise'], examples['hypothesis'])
    ]
    tokenized_inputs = tokenizer(
        concatenated_sentences,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    return tokenized_inputs


model = AutoModelForSequenceClassification.from_pretrained(
    'ErfanMoosaviMonazzah/backpack-gpt2-nli', trust_remote_code=True
)
model.eval()

tokenized_sent = tokenize_function({
    'premise': ['A boy is jumping on skateboard in the middle of a red bridge.',
                'Two women who just had lunch hugging and saying goodbye.',
                'Children smiling and waving at camera'],
    'hypothesis': ['The boy does a skateboarding trick.',
                   'The friends have just met for the first time in 20 years, and have had a great time catching up.',
                   'The kids are frowning']
})

# `predict` is provided by the model's custom code loaded via trust_remote_code.
model.predict(input_ids=tokenized_sent['input_ids'], attention_mask=tokenized_sent['attention_mask'])
|
``` |
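
If the custom model class follows the standard `SequenceClassifierOutput` convention, you can also run a plain forward pass and map the logits to label names yourself. This is a minimal sketch: the label order below follows the SNLI/e-SNLI convention and is an assumption here; prefer `model.config.id2label` if it is populated.

```python
import torch

# Assumed label order (SNLI/e-SNLI convention); use model.config.id2label if it is set.
id2label = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}

# Plain forward pass; this assumes the custom class returns an output object
# with a .logits field, as standard sequence-classification heads do.
with torch.no_grad():
    outputs = model(input_ids=tokenized_sent['input_ids'],
                    attention_mask=tokenized_sent['attention_mask'])

pred_ids = outputs.logits.argmax(dim=-1).tolist()
print([id2label[i] for i in pred_ids])
```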
|
### Training hyperparameters |
|
The following hyperparameters were used during training (an illustrative `TrainingArguments` sketch follows the list):
|
- learning_rate: 5e-5 |
|
- train_batch_size: 64 |
|
- eval_batch_size: 64 |
|
- seed: 2023 |
|
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08 |
|
- lr_scheduler_type: linear |
|
- lr_scheduler_warmup_ratio: 0 |
|
- num_epochs: 3 |
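
The hyperparameters above map roughly onto a Hugging Face `TrainingArguments` configuration like the sketch below. This is an illustrative reconstruction, not the original training script; the output directory and the evaluation/logging cadence are assumptions (the results table logs every 512 steps).

```python
from transformers import TrainingArguments

# Illustrative reconstruction of the hyperparameters listed above.
training_args = TrainingArguments(
    output_dir='backpack-gpt2-nli',   # assumed
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    seed=2023,
    num_train_epochs=3,
    lr_scheduler_type='linear',
    warmup_ratio=0.0,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
    evaluation_strategy='steps',      # assumed, matching the 512-step logging cadence
    eval_steps=512,                   # assumed
    logging_steps=512,                # assumed
)
```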
|
### Training results |
|
| Step | Training Loss | Validation Loss | Precision | Recall | F1 | Accuracy |
|------|---------------|-----------------|-----------|--------|----|----------|
| 512 | 0.614900 | 0.463713 | 0.826792 | 0.824639 | 0.825133 | 0.824731 |
| 1024 | 0.503300 | 0.431796 | 0.844831 | 0.839414 | 0.839980 | 0.839565 |
| 1536 | 0.475600 | 0.400771 | 0.848741 | 0.847009 | 0.846287 | 0.847795 |
| 2048 | 0.455900 | 0.375981 | 0.859064 | 0.857357 | 0.857749 | 0.857448 |
| 2560 | 0.440400 | 0.365537 | 0.862000 | 0.862078 | 0.861917 | 0.862426 |
| 3072 | 0.433100 | 0.365180 | 0.864717 | 0.859693 | 0.860237 | 0.859785 |
| 3584 | 0.425100 | 0.346340 | 0.872312 | 0.870635 | 0.870865 | 0.870961 |
| 4096 | 0.413300 | 0.343761 | 0.873606 | 0.873046 | 0.873174 | 0.873298 |
| 4608 | 0.412000 | 0.344890 | 0.882609 | 0.882120 | 0.882255 | 0.882341 |
| 5120 | 0.402600 | 0.336744 | 0.876463 | 0.875629 | 0.875827 | 0.875737 |
| 5632 | 0.390600 | 0.323248 | 0.882598 | 0.880779 | 0.881129 | 0.880817 |
| 6144 | 0.388300 | 0.338029 | 0.877255 | 0.877041 | 0.877126 | 0.877261 |
| 6656 | 0.390800 | 0.333301 | 0.876357 | 0.876362 | 0.875965 | 0.876753 |
| 7168 | 0.383800 | 0.328297 | 0.883593 | 0.883675 | 0.883629 | 0.883967 |
| 7680 | 0.380800 | 0.331854 | 0.882362 | 0.880373 | 0.880764 | 0.880512 |
| 8192 | 0.368400 | 0.323076 | 0.881730 | 0.881378 | 0.881419 | 0.881528 |
| 8704 | 0.367000 | 0.313959 | 0.889204 | 0.889047 | 0.889053 | 0.889352 |
| 9216 | 0.315600 | 0.333637 | 0.885518 | 0.883965 | 0.884266 | 0.883967 |
| 9728 | 0.303100 | 0.319416 | 0.888667 | 0.888092 | 0.888256 | 0.888234 |
| 10240 | 0.307200 | 0.317827 | 0.887575 | 0.887647 | 0.887418 | 0.888031 |
| 10752 | 0.300100 | 0.311810 | 0.890908 | 0.890827 | 0.890747 | 0.891181 |
| 11264 | 0.303400 | 0.311010 | 0.889871 | 0.887939 | 0.888309 | 0.887929 |
| 11776 | 0.300500 | 0.309282 | 0.891041 | 0.889819 | 0.890077 | 0.889860 |
| 12288 | 0.303600 | 0.326918 | 0.891272 | 0.891250 | 0.890942 | 0.891689 |
| 12800 | 0.300300 | 0.301688 | 0.894516 | 0.894619 | 0.894481 | 0.894940 |
| 13312 | 0.302200 | 0.302173 | 0.896441 | 0.896527 | 0.896462 | 0.896769 |
| 13824 | 0.299800 | 0.293489 | 0.895047 | 0.895172 | 0.895084 | 0.895448 |
| 14336 | 0.294600 | 0.297645 | 0.895865 | 0.896012 | 0.895886 | 0.896261 |
| 14848 | 0.296700 | 0.300751 | 0.895277 | 0.895401 | 0.895304 | 0.895651 |
| 15360 | 0.293100 | 0.293049 | 0.896855 | 0.896705 | 0.896757 | 0.896871 |
| 15872 | 0.293600 | 0.294201 | 0.895933 | 0.895557 | 0.895624 | 0.895651 |
| 16384 | 0.290100 | 0.289367 | 0.897847 | 0.897889 | 0.897840 | 0.898090 |
| 16896 | 0.293600 | 0.283990 | 0.898889 | 0.898724 | 0.898789 | 0.898903 |
| 17408 | 0.285800 | 0.308257 | 0.898250 | 0.898102 | 0.898162 | 0.898293 |
| 17920 | 0.252400 | 0.327164 | 0.898860 | 0.898807 | 0.898831 | 0.899004 |
| 18432 | 0.219500 | 0.315286 | 0.898877 | 0.898835 | 0.898831 | 0.899004 |
| 18944 | 0.217900 | 0.312738 | 0.898857 | 0.898958 | 0.898886 | 0.899207 |
| 19456 | 0.186400 | 0.320669 | 0.899252 | 0.899166 | 0.899194 | 0.899411 |
| 19968 | 0.199000 | 0.316840 | 0.900458 | 0.900455 | 0.900426 | 0.900630 |
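
The precision, recall, and F1 columns above appear to be averaged over the three NLI classes; assuming macro averaging and the standard Hugging Face `Trainer` interface, a `compute_metrics` function along these lines would produce the same four columns. This is a sketch, not the original training code.

```python
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(eval_pred):
    # eval_pred is a (logits, labels) pair as passed by the Hugging Face Trainer.
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    # Macro averaging over the three NLI classes is an assumption here.
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    return {
        'accuracy': accuracy_score(labels, preds),
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }
```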
|
|
|
|
|
## Model Card Authors |
|
|
|
[Erfan Moosavi Monazzah](https://huggingface.co/ErfanMoosaviMonazzah) |