Spaces:
Runtime error
Runtime error
import streamlit as st | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import torch | |
import tokenizers | |
import transformers | |
from transformers import TextClassificationPipeline, AutoTokenizer, AutoModelForSequenceClassification | |
# Streamlit re-executes this whole script on every widget interaction, so the
# tokenizer/model must be cached across reruns or they are reloaded each time.
# ``st.cache_resource`` is the current API; older Streamlit releases only have
# ``st.cache``, which needs ``allow_output_mutation`` for non-hashable objects.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def load_tok_and_model():
    """Load the tokenizer and the sequence-classification model.

    The tokenizer is the stock ``distilbert-base-uncased`` one; the model
    config/weights are read from the current working directory (``"."``).
    Thanks to the caching decorator this runs once per server process rather
    than on every Streamlit rerun.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
    model = AutoModelForSequenceClassification.from_pretrained(".")
    return tokenizer, model
# Topic labels in class-index order: tag[i] is the label for model output i
# (the script pairs ``tag`` element-wise with the probability vector).
tag = ['Cs', 'Econ', 'EESS', 'Math', 'Physics', 'Q-bio', 'Q-fin', 'Stat']
# Class index -> label lookup, derived from ``tag`` so the two tables cannot
# drift out of sync.
inv_map = dict(enumerate(tag))
def predict_label(title, summary, tokenizer, model, inv_map, threshold=0.95):
    """Classify an article and return the top topics covering ``threshold``.

    The title and summary are lower-cased and joined as ``"<title>. <summary>"``
    before tokenization. Topics are then taken in order of decreasing
    probability until their cumulative probability reaches ``threshold``.

    Args:
        title: Article title (may be empty).
        summary: Article summary/abstract (may be empty).
        tokenizer: Object exposing a Hugging Face-style
            ``encode(text, truncation=...)`` returning a list of token ids.
        model: Callable taking a ``(1, seq_len)`` id tensor; element ``[0]``
            of its output must be the logits.
        inv_map: Mapping from class index to label.
        threshold: Cumulative-probability target (default 0.95).

    Returns:
        tuple: ``(prediction_classes, prediction_probs, probs)`` - labels of
        the selected topics, their probabilities as truncated integer
        percentages, and the full softmax probability vector (NumPy array
        indexed by class).
    """
    text = title.lower() + '. ' + summary.lower()
    # Truncate to the tokenizer's maximum length: BERT-family models have a
    # fixed number of position embeddings, so a long abstract would otherwise
    # make the forward pass fail.
    token_ids = tokenizer.encode(text, truncation=True)
    with torch.no_grad():
        logits = model(torch.as_tensor([token_ids]))[0]
    # Batch size is 1, so the last row is the only row of logits.
    probs = torch.softmax(logits[-1, :], dim=-1).data.numpy()

    prediction_classes = []
    prediction_probs = []
    cumulative = 0
    # Walk classes from most to least likely. Bounding the loop by the number
    # of classes avoids an IndexError when the cumulative sum never reaches
    # ``threshold`` (float rounding, or a threshold above 1).
    for class_idx in np.argsort(probs)[::-1]:
        if cumulative >= threshold:
            break
        class_prob = probs[class_idx]
        cumulative += class_prob
        prediction_probs.append(int(100 * class_prob))
        prediction_classes.append(inv_map[int(class_idx)])
    return prediction_classes, prediction_probs, probs
st.title("Classifier of possible topics of articles π")
st.markdown("Please insert the summary and/or title of the article below")

tokenizer, model = load_tok_and_model()

# NOTE: recent Streamlit releases reject st.text_area heights below 68 px
# (StreamlitAPIException), so the title box uses that minimum.
title = st.text_area(label='Title', height=68)
abstract = st.text_area(label='Summary', height=150)

if st.button('Start classifier'):
    # Whitespace-only input counts as empty; either field alone is enough.
    if not title.strip() and not abstract.strip():
        st.markdown("Please fill in the summary and/or the title in the text areas above")
    else:
        prediction_classes, prediction_probs, probs = predict_label(title, abstract, tokenizer, model, inv_map)
        data = pd.DataFrame({'Categories': tag, 'Probs': probs})
        data = data.sort_values(by='Probs', ascending=False)

        # Draw every class in percent on a single y-scale and colour the
        # top-95% classes, instead of stacking 0-100 bars on 0-1 bars.
        top_classes = set(prediction_classes)
        colors = ['tab:orange' if c in top_classes else 'tab:blue' for c in data['Categories']]
        fig, ax = plt.subplots()
        ax.bar(data['Categories'], 100 * data['Probs'], color=colors)
        ax.set_ylabel('Probability, %')
        st.pyplot(fig)
        # The Streamlit server process lives across reruns; close the figure
        # so pyplot does not keep accumulating open figures.
        plt.close(fig)

        data_answer = pd.DataFrame({'Categories': prediction_classes, 'Probs, %': prediction_probs})
        st.write('top-95%')
        st.write(data_answer)