File size: 900 Bytes
f24c629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import gradio as grad
import torch
from transformers import BartForSequenceClassification, BartTokenizer

# Load the tokenizer and the sequence-classification model from the same
# pretrained checkpoint (an MNLI-trained BART, judging by its name); the
# tokenizer is loaded first, then the model.
model_name = "facebook/bart-large-mnli"
bart_tokenizer, model = (
    BartTokenizer.from_pretrained(model_name),
    BartForSequenceClassification.from_pretrained(model_name),
)

def classify(text, label):
    """Return the probability, in percent, that *label* applies to *text*.

    Zero-shot classification via natural-language inference: *text* is fed
    to the NLI model as the premise and *label*, verbatim, as the
    hypothesis. The neutral logit is dropped and a softmax is taken over the
    remaining contradiction/entailment pair, so the result is the entailment
    probability relative to contradiction only.

    Args:
        text: Text to classify (the NLI premise).
        label: Candidate label (the NLI hypothesis).

    Returns:
        Entailment probability scaled to the range 0-100, as a Python float.
    """
    # Truncate only the premise when the pair exceeds the tokenizer's
    # maximum length, so over-long input is classified instead of failing
    # inside the model; the (short) label is kept intact.
    token_ids = bart_tokenizer.encode(
        text, label, return_tensors="pt", truncation="only_first"
    )
    # Pure inference: disable autograd so no graph is built or kept alive.
    with torch.no_grad():
        logits = model(token_ids).logits
    # Indices assume the facebook/bart-large-mnli label order
    # (0 = contradiction, 1 = neutral, 2 = entailment); re-check them if
    # model_name is changed.
    entail_contra_logits = logits[:, [0, 2]]
    probabilities = entail_contra_logits.softmax(dim=1)
    # Column 1 of the two-way softmax is the entailment probability.
    return probabilities[:, 1].item() * 100

# UI: two single-line text inputs (the text to classify and a candidate
# label) feed classify(); its float result is shown in a text output box.
in_text = grad.Textbox(
    lines=1,
    label="English",
    placeholder="Text to be classified",
)
in_labels = grad.Textbox(
    lines=1,
    label="Label",
    placeholder="Input a label",
)
out = grad.Textbox(
    lines=1,
    label="Probability of label being true is ",
)

# Build the interface and start the web server immediately at import time.
grad.Interface(
    fn=classify,
    inputs=[in_text, in_labels],
    outputs=[out],
).launch()