import torch
from transformers.models.bert import BertTokenizer, BertForSequenceClassification
from transformers import AutoTokenizer, AutoModelForSequenceClassification


# Load the model architecture from COLD (binary classification head), then
# overwrite its weights with the fine-tuned checkpoint.
model_name = "thu-coai/roberta-base-cold"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
model_path = "finetuned_cold_LoL.pth"  # Fine-tuned weights; can be downloaded from this repo.
# map_location="cpu" lets a checkpoint that was saved from a GPU be
# deserialized on a CPU-only machine. load_state_dict then copies the tensors
# into the model's existing parameters.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source; on torch >= 1.13 consider
# passing weights_only=True.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
# Inference only: make evaluation mode explicit (disables dropout).
model.eval()


# Demo for toxicity detection.
texts = ['狠狠地导', '卡了哟', 'gala有卡莎皮肤,你们这些小黑子有吗?', '早改了,改成回血了']
# padding=True pads the batch to its longest text. truncation=True stops
# inputs longer than the model's maximum length from raising an error.
model_input = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
# We only need predictions, so skip building the autograd graph.
with torch.no_grad():
    model_output = model(**model_input, return_dict=False)
# model_output[0] holds the per-class scores for each text. argmax over the
# last dimension picks the predicted class; tolist() yields plain Python ints.
prediction = torch.argmax(model_output[0].cpu(), dim=-1).tolist()
# Expected for the demo texts: [1, 0, 1, 0]  (1 = toxic, 0 = non-toxic)
# Downloads last month: -
#
# Downloads are not tracked for this model (see "How to track" on the model page).
#
# Inference Providers (new):
# This model is not currently available via any of the supported Inference Providers.
# The model cannot be deployed to the HF Inference API because it has no library tag.