# Captured from a Spaces app page (commit f24c629, 900 bytes).
# Space status at capture time: "Runtime error".
import gradio as grad
import torch
from transformers import BartForSequenceClassification, BartTokenizer
# Checkpoint on the Hugging Face Hub used for both tokenizer and model.
model_name = "facebook/bart-large-mnli"

# Load the tokenizer first, then the sequence-classification head
# (tuple elements are evaluated left to right, preserving load order).
bart_tokenizer, model = (
    BartTokenizer.from_pretrained(model_name),
    BartForSequenceClassification.from_pretrained(model_name),
)
def classify(text, label):
    """Score how strongly *text* supports the hypothesis *label*.

    The pair (premise=text, hypothesis=label) is run through the module-level
    ``model``. Only logit columns 0 and 2 are kept, and a softmax is taken
    over those two.

    Args:
        text: Premise text to classify.
        label: Candidate label, used verbatim as the hypothesis.

    Returns:
        float: Probability of the kept column at index 2, as a percentage
        in [0, 100].
    """
    # Truncate over-long text/label pairs instead of failing when they
    # exceed the tokenizer's maximum sequence length.
    token_ids = bart_tokenizer.encode(
        text, label, return_tensors="pt", truncation=True
    )
    # Inference only: no autograd graph is needed, which saves memory/time.
    with torch.no_grad():
        token_logits = model(token_ids)[0]  # first output element = logits
    # Keep columns 0 and 2 and drop column 1. Per the variable name these
    # are presumably contradiction (0) and entailment (2) for
    # bart-large-mnli — confirm against model.config.id2label.
    entail_contra_token_logits = token_logits[:, [0, 2]]
    probabilities = entail_contra_token_logits.softmax(dim=1)
    # Column 1 of the reduced tensor is original logit index 2.
    response = probabilities[:, 1].item() * 100
    return response
# Input widgets: the text to classify and the candidate label.
in_text = grad.Textbox(
    label="English",
    placeholder="Text to be classified",
    lines=1,
)
in_labels = grad.Textbox(
    label="Label",
    placeholder="Input a label",
    lines=1,
)
# Output widget showing the value returned by classify().
out = grad.Textbox(label="Probability of label being true is ", lines=1)

# Wire classify(text, label) to the widgets and start the server.
# Note: launch() runs at import time, as in the original script.
grad.Interface(
    fn=classify,
    inputs=[in_text, in_labels],
    outputs=[out],
).launch()