---
license: mit
datasets:
  - li2017dailydialog/daily_dialog
language:
  - en
pipeline_tag: text-classification
---

This repository hosts a Gen AI final project.

# Transformer with Emotion Classification

## Overview

This is a Transformer-based model for emotion classification and dialogue act recognition on the DailyDialog dataset. It processes multi-turn dialogues to predict emotional states and communication intentions. A Stacked Autoencoder (SAE) regularizes node usage, encouraging sparsity in the feature representations.

While the model predicts dialogue acts successfully, it struggles with emotion classification, often collapsing to binary labels (0 or 1) due to imbalanced data or other limitations.


## Model Details

### Model Architecture

- Transformer Encoder: a standard Transformer encoder serves as the backbone for extracting contextual features from dialogues.
- Batch Normalization: normalizes the extracted features.
- Dropout: reduces overfitting.
- Stacked Autoencoder (SAE): regularizes feature representations by encouraging sparsity, adding a KL divergence loss during training (a sketch follows this list).
- Classification Heads:
  - Dialogue Act Classifier: predicts communication intentions (e.g., inform, question).
  - Emotion Classifier: predicts one of the annotated emotions (e.g., happiness, sadness, anger).
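
A minimal sketch of this design, assuming hypothetical sizes and names (`DualHeadTransformer`, `sae_dim`, and the sparsity target `rho` are illustrative, not taken from the original code):

```python
import torch
import torch.nn as nn

class DualHeadTransformer(nn.Module):
    """Illustrative sketch of the architecture described above."""
    def __init__(self, vocab_size=30522, d_model=256, n_heads=4,
                 n_layers=2, n_acts=4, n_emotions=7, sae_dim=64, rho=0.05):
        super().__init__()
        self.rho = rho  # target mean activation for the SAE sparsity penalty
        self.embed = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        self.norm = nn.BatchNorm1d(d_model)
        self.dropout = nn.Dropout(0.1)
        # Stacked autoencoder: sigmoid bottleneck so activations live in (0, 1)
        self.sae_enc = nn.Sequential(nn.Linear(d_model, sae_dim), nn.Sigmoid())
        self.sae_dec = nn.Linear(sae_dim, d_model)
        self.act_head = nn.Linear(d_model, n_acts)
        self.emotion_head = nn.Linear(d_model, n_emotions)

    def forward(self, input_ids, attention_mask):
        h = self.encoder(self.embed(input_ids),
                         src_key_padding_mask=(attention_mask == 0))
        # Mean-pool over non-padding tokens
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (h * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        pooled = self.dropout(self.norm(pooled))
        z = self.sae_enc(pooled)
        recon = self.sae_dec(z)
        # KL divergence between the target sparsity rho and the mean activation
        rho_hat = z.mean(dim=0).clamp(1e-6, 1 - 1e-6)
        kl = (self.rho * torch.log(self.rho / rho_hat)
              + (1 - self.rho) * torch.log((1 - self.rho) / (1 - rho_hat))).sum()
        return self.act_head(recon), self.emotion_head(recon), kl
```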

### Input

- `sentence`: a single word or sentence.

### Output

- `act_output`: predicted dialogue act class.
- `emotion_output`: predicted emotion class.

## Dataset: DailyDialog

The model is trained and evaluated on the DailyDialog dataset.

### Dataset Features

- `dialog`: multi-turn conversations, each a list of utterances.
- `act`: a dialogue act label per utterance (inform, question, directive, commissive).
- `emotion`: an emotion label per utterance (no emotion, anger, disgust, fear, happiness, sadness, surprise).
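The dataset can be pulled straight from the Hub. A short sketch (the `trust_remote_code` flag is an assumption that may be required on recent `datasets` versions, since DailyDialog ships a loading script):

```python
from datasets import load_dataset

# Load DailyDialog from the Hugging Face Hub
ds = load_dataset("li2017dailydialog/daily_dialog", trust_remote_code=True)

sample = ds["train"][0]
print(sample["dialog"][:2])   # first two utterances of the dialogue
print(sample["act"][:2])      # integer dialogue-act label per utterance
print(sample["emotion"][:2])  # integer emotion label per utterance
```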
## Training and Inference

### Training

1. Dataset preparation: use a tokenizer to preprocess the DailyDialog dataset into `input_ids` and `attention_mask`.
2. Training steps (see the sketch after this list):
   - Forward the input through the model.
   - Compute cross-entropy losses for the dialogue act and emotion classifiers.
   - Add the KL divergence loss for SAE regularization.
   - Backpropagate and update the parameters.
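
A minimal sketch of one training step, assuming the hypothetical `DualHeadTransformer` from the architecture section, a generic BERT tokenizer, an assumed `kl_weight`, and made-up label ids:

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = DualHeadTransformer()  # hypothetical model sketched above
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
kl_weight = 0.1  # assumed weight for the SAE sparsity term

utterances = ["Good morning!", "How are you today?"]
act_labels = torch.tensor([1, 2])      # illustrative dialogue-act ids
emotion_labels = torch.tensor([4, 0])  # illustrative emotion ids

# Forward pass, combined loss, backprop, and parameter update
batch = tokenizer(utterances, padding=True, truncation=True, return_tensors="pt")
act_logits, emo_logits, kl = model(batch["input_ids"], batch["attention_mask"])
loss = (F.cross_entropy(act_logits, act_labels)
        + F.cross_entropy(emo_logits, emotion_labels)
        + kl_weight * kl)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```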

### Inference

- Input: tokenized text sequences and attention masks.
- Output: predicted dialogue act and emotion classes.
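
Continuing the same sketch (hypothetical model and tokenizer from the training example):

```python
# Predict both heads for a single utterance
model.eval()
with torch.no_grad():
    batch = tokenizer(["See you tomorrow!"], return_tensors="pt")
    act_logits, emo_logits, _ = model(batch["input_ids"], batch["attention_mask"])
print("act:", act_logits.argmax(-1).item(),
      "emotion:", emo_logits.argmax(-1).item())
```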

## Limitations

- Emotion classification: the model struggles to predict diverse emotional states, often outputting binary values (0 or 1).
- Imbalanced dataset: emotion labels in DailyDialog are not evenly distributed, which impacts model performance.
- Limited domain: the dataset focuses on daily conversations, so the model may not generalize well to other dialogue contexts.

## Citation

If you use this model or the DailyDialog dataset, please cite:

```bibtex
@inproceedings{li2017dailydialog,
  title={DailyDialog: A Manually Labelled Multi-turn Dialogue Dataset},
  author={Li, Yanran and Su, Hui and Shen, Xiaoyu and Li, Wenjie and Cao, Ziqiang and Niu, Shuzi},
  booktitle={Proceedings of the Eighth International Joint Conference on Natural Language Processing (IJCNLP)},
  year={2017}
}
```

## Info
License: MIT

Original code: https://github.com/hyunwoongko/transformer

My version: https://github.com/Agaresd47/transformer_SAE

Data source: https://huggingface.co/datasets/li2017dailydialog/daily_dialog



## Usage
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch.nn.functional as F

# Load the model and tokenizer (replace with the actual model id on the Hub)
model = AutoModelForSequenceClassification.from_pretrained("agaresd/your-model-name")
tokenizer = AutoTokenizer.from_pretrained("agaresd/your-model-name")

# DailyDialog emotion label mapping
label_mapping = {
    0: "no emotion",
    1: "anger",
    2: "disgust",
    3: "fear",
    4: "happiness",
    5: "sadness",
    6: "surprise",
}

# Input text
input_text = "happy"

# Tokenize and get model outputs
model.eval()
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Get logits, apply softmax, and find the predicted class
logits = outputs.logits
probabilities = F.softmax(logits, dim=-1)
predicted_class = torch.argmax(probabilities, dim=-1).item()

# Map the predicted class to a word
predicted_label = label_mapping[predicted_class]
print(f"Input: {input_text}")
print(f"Predicted Label: {predicted_label}")