---
license: mit
language:
- en
- zh
---

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the COLD model architecture and tokenizer, then the fine-tuned parameters.
model_name = "thu-coai/roberta-base-cold"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

model_path = "finetuned_cold_LoL.pth"  # Can be downloaded from this repo.
model.load_state_dict(torch.load(model_path))
model.eval()  # inference mode

# Demo: toxicity detection on example game-chat comments (Chinese).
texts = ['狠狠地导', '卡了哟', 'gala有卡莎皮肤,你们这些小黑子有吗?', '早改了,改成回血了']
model_input = tokenizer(texts, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(**model_input).logits

prediction = torch.argmax(logits, dim=-1).tolist()
# prediction = [1, 0, 1, 0]  # 1 for toxic, 0 for non-toxic
```
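
To make the output easier to read, the integer predictions can be mapped back to label names (as noted above, 1 means toxic and 0 means non-toxic). A minimal follow-up sketch that reuses `texts` and `prediction` from the snippet above; the `label_names` mapping is illustrative and not part of the model config:

```python
# Illustrative mapping from class index to label name; assumes `texts` and
# `prediction` are already defined by the snippet above.
label_names = {0: "non-toxic", 1: "toxic"}

for text, pred in zip(texts, prediction):
    print(f"{text}\t{label_names[pred]}")
```

For larger batches, the model and the tokenized inputs can also be moved to a GPU (e.g. `model.to("cuda")` and `model_input.to("cuda")`) before the forward pass.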