go-emotions-polish-gpt2-small-v0.0.1

This model is a fine-tuned version of sdadas/polish-gpt2-small on a machine-translated (Polish) version of the google-research-datasets/go_emotions dataset. It achieves the following results on the evaluation set (each list contains the results for the thresholds [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]):

  • Loss: 0.2123
  • Hamming Accuracy: [0.9556733317272946, 0.9592094848057267, 0.9614034483945348, 0.9631328079292425, 0.96438895963107, 0.9652665450665933, 0.9654558281997453]
  • F1 Macro: [0.4929423141608899, 0.49513962905111936, 0.48929787051637963, 0.4797491530618914, 0.4647116651469601, 0.44570651699340547, 0.40534073938214327]
  • Precision Macro: [0.45136216505815147, 0.4817190614473335, 0.5085113721947904, 0.5297915362044485, 0.5555560780209522, 0.5895901759803143, 0.6380967706137726]
  • Recall Macro: [0.551217261404107, 0.5185202616828293, 0.4824209900009455, 0.4506401023339946, 0.4144524017062031, 0.3745523616734555, 0.3152333101205886]

Try it here: spaces/nie3e/polish-emotions

Model description

Fine-tuned from sdadas/polish-gpt2-small for multi-label emotion classification.

Intended uses & limitations

Detecting the emotions described in the GoEmotions paper (arXiv:2005.00547).

Labels:

0: admiration
1: amusement
2: anger
3: annoyance
4: approval
5: caring
6: confusion
7: curiosity
8: desire
9: disappointment
10: disapproval
11: disgust
12: embarrassment
13: excitement
14: fear
15: gratitude
16: grief
17: joy
18: love
19: nervousness
20: optimism
21: pride
22: realization
23: relief
24: remorse
25: sadness
26: surprise
27: neutral

How to use:

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "nie3e/go-emotions-polish-gpt2-small-v0.0.1"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint, problem_type="multi_label_classification"
).to(device)

text = "To jest model wykrywajฤ…cy super emocje w tekล›cie! :D"

# Tokenize the text and compute per-label logits
inputs = tokenizer(text, return_tensors="pt").to(device)
logits = model(**inputs)["logits"].to("cpu")

# Keep the label ids whose sigmoid probability exceeds the threshold
threshold = 0.3
predicted_class_ids = torch.arange(
    0, logits.shape[-1]
)[torch.sigmoid(logits).squeeze(dim=0) > threshold]

percent = torch.sigmoid(logits).squeeze(dim=0)

id2class = model.config.id2label
print([id2class[c] for c in predicted_class_ids.tolist()])
print({id2class[i]: f"{(p*100):.2f}%" for i, p in enumerate(percent.tolist())})
['joy']
{'admiration': '17.75%', 'amusement': '11.22%', 'anger': '0.07%', 'annoyance': '0.36%', 'approval': '6.63%', 'caring': '0.84%', 'confusion': '0.22%', 'curiosity': '0.58%', 'desire': '0.40%', 'disappointment': '1.29%', 'disapproval': '0.26%', 'disgust': '0.11%', 'embarrassment': '0.08%', 'excitement': '25.88%', 'fear': '0.54%', 'gratitude': '0.41%', 'grief': '0.90%', 'joy': '63.62%', 'love': '11.74%', 'nervousness': '0.08%', 'optimism': '1.98%', 'pride': '0.03%', 'realization': '1.19%', 'relief': '0.53%', 'remorse': '0.02%', 'sadness': '0.75%', 'surprise': '0.58%', 'neutral': '8.93%'}

Or using the text-classification pipeline:

from transformers import pipeline

# Set the model, tokenizer, device, text as above

pipe = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=-1,
    device=device
)

result = pipe(text)
print(result)
[[{'label': 'joy', 'score': 0.6362035274505615}, {'label': 'excitement', 'score': 0.2588024437427521}, {'label': 'admiration', 'score': 0.17747776210308075}, {'label': 'love', 'score': 0.11739460378885269}, {'label': 'amusement', 'score': 0.11221607774496078}, {'label': 'neutral', 'score': 0.08927429467439651}, {'label': 'approval', 'score': 0.0662560984492302}, {'label': 'optimism', 'score': 0.019809801131486893}, {'label': 'disappointment', 'score': 0.012886008247733116}, {'label': 'realization', 'score': 0.011940046213567257}, {'label': 'grief', 'score': 0.009018097072839737}, {'label': 'caring', 'score': 0.008446046151220798}, {'label': 'sadness', 'score': 0.007472767494618893}, {'label': 'curiosity', 'score': 0.0058141243644058704}, {'label': 'surprise', 'score': 0.005764781963080168}, {'label': 'fear', 'score': 0.00539048807695508}, {'label': 'relief', 'score': 0.005273739341646433}, {'label': 'gratitude', 'score': 0.004061913583427668}, {'label': 'desire', 'score': 0.003967347089201212}, {'label': 'annoyance', 'score': 0.0036265170201659203}, {'label': 'disapproval', 'score': 0.0026028596330434084}, {'label': 'confusion', 'score': 0.0022179142106324434}, {'label': 'disgust', 'score': 0.0011114622466266155}, {'label': 'embarrassment', 'score': 0.0007856030715629458}, {'label': 'nervousness', 'score': 0.0007625268190167844}, {'label': 'anger', 'score': 0.0007304779137484729}, {'label': 'pride', 'score': 0.0003317077935207635}]]
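
Since the pipeline returns a score for every label, the same thresholding as in the raw example can be applied to its output:

threshold = 0.3
predicted = [r["label"] for r in result[0] if r["score"] > threshold]
print(predicted)
['joy']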

Training and evaluation data

Dataset: google-research-datasets/go_emotions

Preprocessing (a sketch of the filtering steps follows the list):

  • dropping rows exceeding 256 tokens (GPT-2 tokenizer) and rows containing five or more consecutive digits
  • machine translation into Polish
  • removing 80% of the rows whose only label is neutral
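
A minimal sketch of the filtering steps, assuming the datasets library and the simplified GoEmotions config; which GPT-2 tokenizer was used for the length filter is an assumption, and the machine-translation step itself is not shown:

import re
import random

from datasets import load_dataset
from transformers import AutoTokenizer

ds = load_dataset("google-research-datasets/go_emotions", "simplified")
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # tokenizer choice is an assumption

def keep(example):
    # drop rows longer than 256 tokens or containing five or more consecutive digits
    too_long = len(tokenizer(example["text"])["input_ids"]) > 256
    digit_run = re.search(r"\d{5,}", example["text"]) is not None
    return not (too_long or digit_run)

ds = ds.filter(keep)

# ... machine translation into Polish happens here (not shown) ...

# keep only 20% of the rows whose sole label is neutral (label id 27)
random.seed(42)

def downsample_neutral(example):
    if example["labels"] == [27]:
        return random.random() < 0.2
    return True

ds = ds.filter(downsample_neutral)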

Training procedure

Training hyperparameters

The following hyperparameters were used during training (a TrainingArguments sketch follows the list):

  • learning_rate: 2e-05
  • train_batch_size: 8
  • eval_batch_size: 2
  • seed: 42
  • gradient_accumulation_steps: 8
  • total_train_batch_size: 64
  • optimizer: AdamW (torch) with betas=(0.9, 0.999) and epsilon=1e-08; no additional optimizer arguments
  • lr_scheduler_type: linear
  • lr_scheduler_warmup_ratio: 0.1
  • num_epochs: 10
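
Assuming the standard Hugging Face Trainer setup, the hyperparameters above correspond roughly to the following TrainingArguments (output directory and evaluation strategy are illustrative):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="go-emotions-polish-gpt2-small",  # illustrative
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,  # effective train batch size: 64
    num_train_epochs=10,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    seed=42,
    eval_strategy="epoch",  # per-epoch evaluation, as in the results table below
)
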
Trainer subclass applying class weights in the loss:
import torch
from transformers import Trainer


class WeightedTrainer(Trainer):
    def __init__(self, class_weights=None, **kwargs):
        super().__init__(**kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        # Binary cross-entropy with per-class positive weights (multi-label)
        loss_fct = torch.nn.BCEWithLogitsLoss(pos_weight=self.class_weights)
        loss = loss_fct(logits, labels.float())

        return (loss, outputs) if return_outputs else loss
Class weights:
import numpy as np

# Weight for each class: (N - count) / count, clipped to [0.1, 10.0]
df = tokenized_dataset["train"].to_pandas()
label_counts = df["labels_h1"].explode().value_counts().sort_index()
total_samples = len(df)
class_weights = [(total_samples - count) / count for count in label_counts]
class_weights = np.clip(class_weights, 0.1, 10.0)
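
A sketch of how the pieces might be wired together (the base-model initialisation, split names, and the tensor/device handling of the weights are assumptions; training_args comes from the sketch above and compute_metrics is defined in the next section):

import torch
from transformers import AutoModelForSequenceClassification

# Base model initialised for 28-label multi-label classification
model = AutoModelForSequenceClassification.from_pretrained(
    "sdadas/polish-gpt2-small",
    num_labels=28,
    problem_type="multi_label_classification",
)

# pos_weight must be a float tensor on the same device as the logits
weights = torch.tensor(class_weights, dtype=torch.float32)
weights = weights.to("cuda" if torch.cuda.is_available() else "cpu")

trainer = WeightedTrainer(
    class_weights=weights,
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],  # split name assumed
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
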
Metrics computation:
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score

def sigmoid(x):
    return 1/(1 + np.exp(-x))


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    probabilities = sigmoid(predictions)
    thresholds = np.arange(0.3, 0.91, 0.1)

    computed_metrics = {
        "hamming_accuracy": [],
        "f1_macro": [],
        "precision_macro": [],
        "recall_macro": []
    }

    for th in thresholds:
        binary_preds = (probabilities > th).astype(int)

        # Hamming Accuracy (for multi-label)
        hamming_acc = 1 - np.mean(binary_preds != labels)

        # Macro-averaged F1/Precision/Recall
        f1 = f1_score(labels, binary_preds, average="macro", zero_division=0)
        precision = precision_score(labels, binary_preds, average="macro",
                                    zero_division=0)
        recall = recall_score(labels, binary_preds, average="macro",
                              zero_division=0)

        computed_metrics["hamming_accuracy"].append(hamming_acc)
        computed_metrics["f1_macro"].append(f1)
        computed_metrics["precision_macro"].append(precision)
        computed_metrics["recall_macro"].append(recall)

    return computed_metrics
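
As a quick sanity check, compute_metrics can be exercised on made-up logits and multi-hot labels (the numbers below are purely illustrative):

import numpy as np

# two examples, three labels: raw logits and the corresponding multi-hot labels
fake_logits = np.array([[2.0, -1.0, 0.5],
                        [-2.0, 1.5, -0.5]])
fake_labels = np.array([[1, 0, 1],
                        [0, 1, 0]])

metrics = compute_metrics((fake_logits, fake_labels))
print(metrics["hamming_accuracy"])  # one value per threshold in [0.3, ..., 0.9]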

Training results

Each list contains the results for the thresholds [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9].

Results (each row: epoch, training loss, validation loss, then per-threshold Hamming accuracy, F1 macro, precision macro, recall macro):
1 0.27 0.1772 [0.9512768007708986, 0.959080428124032, 0.9628144681143959, 0.9640103933647658, 0.9644922049764256, 0.9635629968682246, 0.9613690332794164] [0.46241863342821793, 0.4532072267866339, 0.4326439737447529, 0.3910675952043853, 0.34917671255121363, 0.2895274759086305, 0.17584749438485497] [0.4461301409377742, 0.5054615942184383, 0.5451358075233073, 0.5833130459387715, 0.612234133178668, 0.6415091314766198, 0.5974110920754876] [0.5228623671145524, 0.44986334764237274, 0.3906724180081668, 0.32626227826074566, 0.27496335348059725, 0.21568979328206703, 0.12657202290169853]
2 0.1671 0.1609 [0.9546752933888564, 0.9605000516226727, 0.9636834497711395, 0.9650600543758819, 0.9657311491206938, 0.9651977148363561, 0.9638813366830712] [0.47181939799853245, 0.4773149021343449, 0.47247182216640654, 0.4451366460170004, 0.42366333538332507, 0.37403397926368737, 0.30313475002048484] [0.441959190629794, 0.5081107432668841, 0.561701577155539, 0.5801413709101412, 0.6462220107453076, 0.6764110777613557, 0.7061705442835556] [0.580901954973524, 0.5214683734393482, 0.47275669773690804, 0.40899548004178854, 0.36579701720117563, 0.3034817689033619, 0.23056933111221325]
3 0.1399 0.1606 [0.9497109130330041, 0.9569208796503424, 0.9613518257218571, 0.9638211102316138, 0.965249337509034, 0.9654988470936435, 0.9639931858072065] [0.4749602125981087, 0.4873423290936483, 0.48338025379077687, 0.4701566587566043, 0.4498896201717952, 0.41177613984294953, 0.32623322721112374] [0.41117966940106204, 0.45764702886929787, 0.5046222714256859, 0.5530253088371339, 0.6005644340448025, 0.6760814655629932, 0.6800566831000542] [0.6222726145233521, 0.5734770801943697, 0.5113684601796389, 0.4550331179550861, 0.4044620988185089, 0.34071151157255797, 0.25092165205453]
4 0.1176 0.1669 [0.95028736621124, 0.9570241249956981, 0.9612141652613828, 0.9638555253467322, 0.9651202808273394, 0.9657827717933717, 0.9645008087552053] [0.46776823075218205, 0.4770370632194352, 0.48170399873205805, 0.4815610729662237, 0.46635824100838313, 0.4346021274443605, 0.35739954616385866] [0.4026523166950933, 0.44682680882818887, 0.4916683449781562, 0.5367728364431649, 0.575993305278815, 0.6472053281375478, 0.6693600674340754] [0.6123919277944883, 0.557746427766825, 0.5176178038314311, 0.47937515940889003, 0.4322437215959729, 0.36783678520926777, 0.27372379016591925]
5 0.0992 0.1735 [0.9527996696148948, 0.9587878996455244, 0.9617906184396187, 0.9635802044257838, 0.9652063186151357, 0.9660236775992016, 0.9651891110575765] [0.48529687411311123, 0.49524580727616085, 0.4889009921348519, 0.476286545728211, 0.4606138233621902, 0.4360905909528836, 0.37325003123769596] [0.43090457512637126, 0.4758805613570571, 0.5059212751069277, 0.5323418850235423, 0.5778588494150572, 0.6303798411419806, 0.7050763694911789] [0.5871641620419744, 0.5468805735159331, 0.5029133763134643, 0.46169826597462366, 0.4141503796261941, 0.3657398096031768, 0.28558429402703295]
6 0.083 0.1841 [0.9520941597549644, 0.9572736345803077, 0.9603795987197578, 0.9626337887600234, 0.9642771105069347, 0.9654042055270675, 0.9656279037753381] [0.4848110691063849, 0.4902428690172167, 0.48618324094573484, 0.478250083701593, 0.4677017272145706, 0.44389902987164165, 0.39728758295841066] [0.4282820729807268, 0.4617066856103255, 0.48889885714082243, 0.515822475491864, 0.548112314645949, 0.5864861700775268, 0.642991780443598] [0.5872655526119982, 0.5489591883468488, 0.5083264155214335, 0.47033662001581716, 0.43490169281095825, 0.384219870731881, 0.3108363940394387]
7 0.0715 0.1922 [0.9541934817771965, 0.958366314485322, 0.960921636782875, 0.9629607323536498, 0.964363148294731, 0.9655160546512028, 0.9651460921636783] [0.49327999654639215, 0.4942623403104429, 0.48696556713303135, 0.4804445324053421, 0.46939295694544736, 0.4475119371916684, 0.3881489924589002] [0.4442084347600441, 0.47508959657212724, 0.5001365624277175, 0.5315837721216776, 0.5635172346685284, 0.6060793900700022, 0.6547980634968182] [0.5749401619525247, 0.5335821079528872, 0.4937842962870135, 0.45744546851621865, 0.4230868747861999, 0.37702257236227427, 0.29712029562165154]
8 0.0609 0.2025 [0.9547011047251953, 0.9585469938396944, 0.961377637058196, 0.9629951474687682, 0.9646040541005609, 0.9654644319785249, 0.9653095639604914] [0.4929084055191263, 0.4966642184262872, 0.49509852255405307, 0.47826643819598524, 0.4664438311012158, 0.4433651185391054, 0.38896780252101004] [0.45156391170754395, 0.48340644816595113, 0.5119596021038765, 0.5318185543351281, 0.5669432431065436, 0.6023885415673448, 0.640319754629273] [0.5590029919029157, 0.5266729465629012, 0.49473318351169737, 0.45252148670174985, 0.4154957108440041, 0.37123594052789305, 0.2964469621468265]
9 0.0524 0.2099 [0.9549592180885845, 0.9585986165123722, 0.9611195236948068, 0.9629435247960905, 0.9642771105069347, 0.9651030732697801, 0.9651805072787969] [0.4920579362491578, 0.48948143585573084, 0.48373280918321976, 0.4765803308742461, 0.464925139967501, 0.44204098321531043, 0.3994121381701787] [0.4481765728464743, 0.4727514944675648, 0.49823474126036366, 0.5251829103863094, 0.5535803816229916, 0.5863930625014495, 0.6280255957220133] [0.5567520239062707, 0.5174123200187039, 0.4811964159610151, 0.4494146051029743, 0.41590442650848564, 0.37120008498454643, 0.30929707478684687]
10 0.0479 0.2123 [0.9556733317272946, 0.9592094848057267, 0.9614034483945348, 0.9631328079292425, 0.96438895963107, 0.9652665450665933, 0.9654558281997453] [0.4929423141608899, 0.49513962905111936, 0.48929787051637963, 0.4797491530618914, 0.4647116651469601, 0.44570651699340547, 0.40534073938214327] [0.45136216505815147, 0.4817190614473335, 0.5085113721947904, 0.5297915362044485, 0.5555560780209522, 0.5895901759803143, 0.6380967706137726] [0.551217261404107, 0.5185202616828293, 0.4824209900009455, 0.4506401023339946, 0.4144524017062031, 0.3745523616734555, 0.3152333101205886]

Framework versions

  • Transformers 4.48.3
  • Pytorch 2.5.1+cu124
  • Datasets 3.2.0
  • Tokenizers 0.21.0