import streamlit as st
from streamlit import session_state
# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline

# Page configuration goes first: Streamlit expects set_page_config to be the
# first Streamlit command executed by the script.
st.set_page_config(page_title="Classification", page_icon="📈")

_MODEL_ID = "themeetjani/tweet-classification"


@st.cache_resource
def _load_classifier():
    """Build the text-classification pipeline once per server process.

    Streamlit re-executes this script on every user interaction; without
    caching, the tokenizer and model weights would be reloaded on each rerun.

    Returns:
        A ``transformers`` text-classification pipeline that truncates its
        input to at most 512 tokens.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(_MODEL_ID)
    return pipeline(
        "text-classification",
        model=loaded_model,
        tokenizer=loaded_tokenizer,
        truncation=True,
        max_length=512,
    )


classifier = _load_classifier()

# Module-level aliases kept so code that refers to `tokenizer` / `model`
# directly keeps working.
tokenizer = classifier.tokenizer
model = classifier.model

# Seed the result slot so the result widget has a value on the first run.
if 'tweet_class' not in session_state:
    session_state['tweet_class'] = ""

def classify(tweet):
    """Classify *tweet* and store the predicted label in the session state.

    The label of the single top-ranked prediction (``top_k=1``) is written to
    ``session_state['tweet_class']``. Blank input (empty or whitespace-only)
    skips the model and clears the stored result instead of producing a
    label for text that carries no content.

    Args:
        tweet: Text of the tweet to classify.
    """
    if not tweet or not tweet.strip():
        session_state['tweet_class'] = ""
        return

    # Keep the raw prediction list local; only the label string belongs in
    # the session state, which the result widget displays as text.
    predicted_classes = classifier(tweet, top_k=1)
    print(tweet)
    print(predicted_classes)
    session_state['tweet_class'] = predicted_classes[0]['label']

st.title("Tweet Classifier")

# The key exposes the widget's live value as session_state["tweet_input"],
# which is what the button callback reads.
tweet = st.text_area(label="Please write the tweet bellow",
                     placeholder="What does the tweet say?",
                     key="tweet_input")

st.text_area("result", value=session_state['tweet_class'])


def _classify_current_tweet():
    """Classify the text currently held by the tweet input widget.

    Passing ``args=[tweet]`` to the button would bind the value the text area
    had during the script run that rendered the button, which can lag behind
    what the user has typed by the time they click. Reading the keyed widget
    value inside the callback always sees the latest text.
    """
    classify(session_state["tweet_input"])


st.button("Classify", on_click=_classify_current_tweet)