qc7's picture
Update app.py
068eb5c
raw
history blame
1.72 kB
import streamlit as st
import numpy as np
import pandas as pd
import torch
import transformers
from transformers import TextClassificationPipeline, AutoTokenizer, AutoModelForSequenceClassification
@st.cache(
    suppress_st_warning=True,
    # Without this, st.cache hashes the returned (tokenizer, model) after every
    # call to detect mutation, which means hashing the whole model on reruns.
    allow_output_mutation=True,
    hash_funcs={transformers.AutoTokenizer: lambda _: None},
)
def load_tok_and_model(model_name=".", model_path="."):
    """Load the tokenizer and the sequence-classification model.

    The previous body read an undefined global ``model_name`` and raised
    ``NameError``. It is now a parameter, so the no-argument call still works.

    Args:
        model_name: Hub id or local directory passed to
            ``AutoTokenizer.from_pretrained``. Defaults to "." (the same
            directory the model is loaded from). NOTE(review): confirm the
            tokenizer files are saved next to the model. Otherwise pass the id
            of the base checkpoint the model was fine-tuned from.
        model_path: Hub id or local directory passed to
            ``AutoModelForSequenceClassification.from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    return tokenizer, model
# Human-readable topic names, one per classifier output (forward_pass expects
# exactly 8 logits). NOTE(review): the order presumably mirrors the model's
# label ids — confirm against the training label mapping.
CATEGORIES = [
    "Computer Science",
    "Economics",
    "Electrical Engineering",
    "Mathematics",
    "Q. Biology",
    "Q. Finances",
    "Statistics",
    "Physics",
]
@st.cache(suppress_st_warning=True, hash_funcs={transformers.AutoTokenizer: lambda _: None})
def forward_pass(title, abstract, tokenizer, model):
    """Return the model's class probabilities for one (title, abstract) pair.

    The title and the abstract are tokenized separately:
    - the title is padded/truncated to exactly 32 ids;
    - the abstract is padded/truncated to exactly 480 ids.
    The two id sequences are concatenated into one 512-id input, which is fed
    to the model as a batch of size 1.

    Args:
        title: Article title text.
        abstract: Article abstract text.
        tokenizer: Tokenizer returned by ``load_tok_and_model``.
        model: Sequence-classification model returned by ``load_tok_and_model``.

    Returns:
        NumPy array of shape ``(8,)`` holding softmax probabilities that sum
        to 1.
    """
    title_tensor = torch.tensor(
        tokenizer(title, padding="max_length", truncation=True, max_length=32)["input_ids"]
    )
    abstract_tensor = torch.tensor(
        tokenizer(abstract, padding="max_length", truncation=True, max_length=480)["input_ids"]
    )
    # NOTE(review): no attention_mask is passed, so padding tokens between the
    # title and the abstract are attended to. Presumably this mirrors how the
    # model was trained — confirm before changing.
    embeddings = torch.cat((title_tensor, abstract_tensor))
    assert embeddings.shape == (512,)
    with torch.no_grad():
        # [None] adds the batch dimension; [0] takes the single result back out.
        logits = model(embeddings[None])["logits"][0]
    assert logits.shape == (8,)
    # torch.softmax has no default `dim`; omitting it raised a TypeError.
    # Normalize over the (only) class axis.
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    return probs
st.title("Classification of arXiv articles' main topic")
st.markdown("Please provide both summary and title when possible")

tokenizer, model = load_tok_and_model()

title = st.text_area(label='Title', height=200)
abstract = st.text_area(label='Abstract', height=200)
button = st.button('Run classifier')

if button:
    if not title.strip() and not abstract.strip():
        # With both fields blank the model would still run on padding alone
        # and display probabilities for it.
        st.warning("Please enter a title and/or an abstract.")
    else:
        probs = forward_pass(title, abstract, tokenizer, model)
        # Label each probability with its topic name instead of printing a
        # bare array. Assumes CATEGORIES[i] names model output i — confirm
        # against the model's label mapping.
        results = pd.DataFrame({"Category": CATEGORIES, "Probability": probs})
        results = results.sort_values("Probability", ascending=False).reset_index(drop=True)
        st.write(results)