|
import streamlit as st |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
|
|
import transformers |
|
from transformers import TextClassificationPipeline, AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
@st.cache(suppress_st_warning=True, hash_funcs={transformers.AutoTokenizer: lambda _: None})
def load_tok_and_model(model_name="."):
    """Load the tokenizer and the sequence-classification model.

    Args:
        model_name: Model id or local directory handed to
            ``AutoTokenizer.from_pretrained``. It only selects the
            tokenizer; the model weights are always read from the current
            directory (``"."``).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # `model_name` is a parameter because no module-level `model_name` exists;
    # referencing it as a global raised NameError on the first call.
    # NOTE(review): the "." default assumes the tokenizer files are saved next
    # to the model checkpoint -- confirm, or pass the base model id explicitly.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(".")
    return tokenizer, model
|
|
|
|
|
# Human-readable topic labels (eight entries).
# NOTE(review): presumably ordered to match the model's output indices -- confirm
# against the label encoding used at training time.
CATEGORIES = [
    "Computer Science",
    "Economics",
    "Electrical Engineering",
    "Mathematics",
    "Q. Biology",
    "Q. Finances",
    "Statistics",
    "Physics",
]
|
|
|
|
|
@st.cache(suppress_st_warning=True, hash_funcs={transformers.AutoTokenizer: lambda _: None})
def forward_pass(title, abstract, tokenizer, model):
    """Return the class probabilities for a single article.

    The title and the abstract are tokenized separately, each padded or
    truncated to a fixed length (32 and 480 token ids respectively), and the
    two id sequences are concatenated into one 512-id input for the model.

    Args:
        title: Article title text.
        abstract: Article abstract text.
        tokenizer: Callable tokenizer whose result exposes ``'input_ids'``.
        model: Callable model whose result exposes ``'logits'``.

    Returns:
        A 1-D NumPy array with 8 probabilities (softmax over the logits).
    """
    title_ids = torch.tensor(
        tokenizer(title, padding="max_length", truncation=True, max_length=32)['input_ids']
    )
    abstract_ids = torch.tensor(
        tokenizer(abstract, padding="max_length", truncation=True, max_length=480)['input_ids']
    )

    input_ids = torch.cat((title_ids, abstract_ids))
    assert input_ids.shape == (512,)

    # NOTE(review): only the input ids are passed (no attention mask), so the
    # padding between title and abstract is presumably treated like regular
    # tokens -- confirm this matches how the model was trained.
    with torch.no_grad():
        # `[None]` adds the batch dimension; `[0]` removes it again.
        logits = model(input_ids[None])['logits'][0]
    assert logits.shape == (8,)

    # torch.softmax requires an explicit `dim`; calling it without one raises
    # TypeError. The logits are 1-D, so normalize over the last (only) axis.
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    return probs
|
|
|
# ---- Streamlit page: collect the inputs, run the classifier, show the result ----
st.title("Classification of arXiv articles' main topic")
st.markdown("Please provide both summary and title when possible")

tokenizer, model = load_tok_and_model()

title = st.text_area(label='Title', height=200)
abstract = st.text_area(label='Abstract', height=200)
button = st.button('Run classifier')

if button:
    if not (title.strip() or abstract.strip()):
        # With both fields empty the model would only see padding tokens.
        st.warning("Please enter a title and/or an abstract first.")
    else:
        probs = forward_pass(title, abstract, tokenizer, model)
        # Attach a topic label to every probability and list the most likely
        # topic first, instead of printing a bare array of numbers.
        # NOTE(review): assumes CATEGORIES is ordered like the model outputs.
        result = pd.DataFrame({"category": CATEGORIES, "probability": probs})
        st.write(result.sort_values("probability", ascending=False))
|
|