A5hbr1ng3r commited on
Commit
48cb8ea
·
verified ·
1 Parent(s): 7eb2450

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +23 -0
README.md CHANGED
@@ -1,3 +1,26 @@
1
  ---
2
  license: mit
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  ---
4
+
5
+ ```python
6
+ import torch
7
+ from transformers.models.bert import BertTokenizer, BertForSequenceClassification
8
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
+
10
+
11
+ # Load model architecture from COLD and load fine-tuned params.
12
+ model_name = "thu-coai/roberta-base-cold"
13
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
14
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
15
+ model_path = "finetuned_cold_LoL.pth"
16
+ model.load_state_dict(torch.load(model_path))
17
+
18
+
19
+ # Demo for toxicity detection
20
+ texts = ['狠狠地导', '卡了哟', 'gala有卡莎皮肤,你们这些小黑子有吗?', '早改了,改成回血了']
21
+ model_input = tokenizer(texts, return_tensors="pt", padding=True)
22
+ model_output = model(**model_input, return_dict=False)
23
+ prediction = torch.argmax(model_output[0].cpu(), dim=-1)
24
+ prediction = [p.item() for p in prediction]
25
+ # prediction = [1, 0, 1, 0] # 1 for toxic, 0 for non-toxic
26
+ ```