---
library_name: transformers
license: apache-2.0
language:
- en
base_model:
- sentence-transformers/all-distilroberta-v1
---

# all-distilroberta-ce-esci

<!-- Provide a quick summary of what the model is/does. -->

This is a cross-encoder model for classifying the relevance of e-commerce query-product pairs.

## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is a cross-encoder fine-tuned from all-distilroberta-v1 on the [Amazon Shopping Queries dataset](https://github.com/amazon-science/esci-data/tree/main/shopping_queries_dataset) of query-product pairs. The model encodes a query and a product text jointly and predicts the pair's relevance class in the ESCI framework (Exact, Substitute, Complement, Irrelevant). Its predictions can be used directly for multi-class classification or as inputs to more complex downstream tasks.

- **Developed by:** Sarah Lawlis / DASC Practicum Team 12
- **Shared by:** University of Arkansas Data Science Practicum Team 12
- **Model type:** Sequence Classification (Cross-Encoder)
- **Language(s) (NLP):** English
- **License:** apache-2.0
- **Finetuned from model:** sentence-transformers/all-distilroberta-v1

### Model Sources

<!-- Provide the basic links for the model. -->

- **Repository:** [sllawlis/distilroberta-ce-esci](https://huggingface.co/sllawlis/distilroberta-ce-esci)

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

This model is designed for multi-class product relevance classification within the ESCI framework. Given a query-product pair, it directly predicts one of the four ESCI labels. This classification task is the foundation for the downstream use cases.

### Downstream Use

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

The model's multi-class predictions can be used in the following downstream tasks:

1. Ranking Systems:
    * Combine the model's predictions with a bi-encoder in a two-stage ranking pipeline:
        * First Stage (Bi-Encoder): efficiently retrieve candidate products by comparing query and product-title embeddings.
        * Second Stage (Cross-Encoder): re-rank the candidates using the fine-grained ESCI label predictions for better accuracy.
2. Product Substitute Identification:
    * Use the Substitute label predicted by the model to identify products that can replace one another.
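
A minimal sketch of the second (re-ranking) stage follows. It assumes the cross-encoder has already produced four ESCI logits for each candidate returned by the first stage. The sketch converts each candidate's class probabilities into a single expected-gain score and sorts by that score. The gain values in `ESCI_GAINS` and the label order are illustrative assumptions and are not part of this model. In practice, read the label order from `model.config.id2label` and tune the gains for your application.

```python
import math

# Hypothetical gain per ESCI class (an illustrative assumption, tune as needed)
ESCI_GAINS = {"Exact": 1.0, "Substitute": 0.1, "Complement": 0.01, "Irrelevant": 0.0}


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_gain(logits, labels):
    """Probability-weighted gain for one query-product pair."""
    probs = softmax(logits)
    return sum(p * ESCI_GAINS[label] for p, label in zip(probs, labels))


def rerank(candidates, labels):
    """candidates: list of (product_id, logits) for the first-stage hits.

    labels: class names in the same order as the logits.
    Returns (product_id, score) pairs sorted from most to least relevant.
    """
    scored = [(pid, expected_gain(logits, labels)) for pid, logits in candidates]
    return sorted(scored, key=lambda item: item[1], reverse=True)
```

Scoring by expected gain rather than by the argmax label keeps the ranking sensitive to the model's confidence. As a result, candidates that share the same predicted label are still ordered by how strongly the model favors the more relevant classes.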

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

* Bias: The ESCI labels in the training data are heavily imbalanced, so the model's predictions may skew toward the Exact label.
* Limitations: This model is specific to e-commerce data and may not generalize well to other domains. It is optimized for English and may perform poorly on non-English data. Cross-encoders run a full forward pass for every query-product pair, which makes them computationally expensive for large-scale applications and can make real-time inference difficult.
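
A quick way to check for this skew on a labeled held-out set is to compare per-class recall with how often each label is predicted. The sketch below uses plain Python and assumes `y_true` and `y_pred` are lists of ESCI label strings, for example predictions mapped through `model.config.id2label`. The function name `label_skew_report` is hypothetical.

```python
from collections import Counter


def label_skew_report(y_true, y_pred):
    """Return (per-class recall, share of predictions per label)."""
    support = Counter(y_true)
    # Count correct predictions per true class
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    recall = {label: correct[label] / n for label, n in support.items()}
    # Fraction of all predictions that went to each label
    pred_share = {label: c / len(y_pred) for label, c in Counter(y_pred).items()}
    return recall, pred_share
```

If Exact dominates the predicted share while Substitute and Complement recall stay low, the label imbalance is likely affecting the model's predictions.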

## How to Get Started with the Model

Use the code below to get started with the model.

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the tokenizer and model from the Hugging Face Hub
model_name = "sllawlis/distilroberta-ce-esci"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
```
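
## Usage (Multi-class Classification Example)

Below is a quick usage example of this model. It continues from the loading code above.

```python
import torch

# Example query-product pair
query = "wireless headphones"
product = "Noise-cancelling wireless headphones with long battery life"

# Tokenize the query and product together as one cross-encoder input
inputs = tokenizer(
    query,
    product,
    truncation=True,
    padding=True,
    return_tensors="pt",
)

# Predict relevance (inference only: disable dropout and gradient tracking)
model.eval()
with torch.no_grad():
    outputs = model(**inputs)

probs = outputs.logits.softmax(dim=-1)
predicted_class = probs.argmax(dim=-1).item()

# Map the class index to a name; names may be generic (e.g. "LABEL_0")
# if the label mapping was not saved with the model config
print(f"Predicted class: {predicted_class} ({model.config.id2label[predicted_class]})")
```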

## Training Details

### Pre-training

The model is initialized from the pretrained [all-distilroberta-v1](https://huggingface.co/sentence-transformers/all-distilroberta-v1) checkpoint.

### Fine-tuning

The model is fine-tuned for multi-class relevance classification in the ESCI framework. Each training input is a query-product pair. The training objective is classification with cross-entropy loss, which aligns the predicted class probabilities with the true labels.
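
The standard multi-class cross-entropy objective over a batch of $N$ query-product pairs is written out below. Here $z_i \in \mathbb{R}^4$ are the model's logits for pair $i$ and $y_i$ is its true ESCI class. This is the textbook unweighted form. The card does not state whether class weights or label smoothing were used, so neither is assumed here.

$$
p_{i,c} = \frac{\exp(z_{i,c})}{\sum_{k=1}^{4} \exp(z_{i,k})}, \qquad
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log p_{i,y_i}
$$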

### Hyperparameters

Training was performed on a Tesla V100-PCIE-32GB GPU with a batch size of 32 for 3 epochs. The learning rate was set to 5e-5 with the AdamW optimizer, and the first 10% of total training steps were used for warm-up. Input sequences were padded to a maximum length of 512 tokens. Validation was run roughly every 10% of an epoch, with micro F1 score and accuracy as the evaluation metrics.
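
The warm-up can be sketched as a learning-rate function of the optimizer step. This sketch assumes a linear warm-up to 5e-5 over the first 10% of steps, followed by linear decay to zero. That is the same shape as `transformers.get_linear_schedule_with_warmup`, but the card does not say which decay, if any, was used after warm-up. `lr_at_step` and `total_training_steps` are hypothetical helpers.

```python
import math


def total_training_steps(num_examples: int, batch_size: int = 32, epochs: int = 3) -> int:
    """Optimizer steps for the run, assuming the last partial batch is kept."""
    return math.ceil(num_examples / batch_size) * epochs


def lr_at_step(step: int, total_steps: int, peak_lr: float = 5e-5,
               warmup_ratio: float = 0.10) -> float:
    """Linear warm-up to peak_lr, then (assumed) linear decay to 0."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Warm-up: ramp linearly from 0 up to peak_lr
        return peak_lr * step / max(1, warmup_steps)
    # Decay: ramp linearly from peak_lr down to 0 at total_steps
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)
```

Suppose all 1,253,756 training tuples were used in each epoch, which is an assumption. Then `total_training_steps(1_253_756)` gives 39,180 steps per epoch and 117,540 steps over 3 epochs, of which 11,754 would be warm-up steps.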

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

| Dataset | Paper | Number of training tuples |
|--------------------------------------------------------|:----------------------------------------:|:--------------------------:|
| [Amazon Shopping Queries Dataset](https://github.com/amazon-science/esci-data/tree/main/shopping_queries_dataset) | [paper](https://arxiv.org/pdf/2206.06588) | 1,253,756 |

## Model Card Authors

[Sarah Lawlis](https://www.linkedin.com/in/sarah-lawlis/)

## Model Card Contact

[email protected]