import gradio as gr
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

# Labels for the multi-class sentiment model. The position of each label is
# the logit index it corresponds to in the classifier's output.
class_names = [
    'Negative',
    'Neutral',
    'Positive',
]

# Sentiment classifier: a pretrained BERT encoder followed by dropout and a
# linear layer that maps BERT's pooler output to one logit per class.
class SentimentClassifier(nn.Module):
    """BERT-based sequence classifier.

    Args:
        n_classes: Number of output classes (size of the logit vector).
        pretrained_name: Name or path passed to ``BertModel.from_pretrained``.
            Defaults to the Luxembourgish ``lothritz/LuxemBERT`` checkpoint.
        dropout: Dropout probability applied to the pooled output.
    """

    def __init__(self, n_classes, pretrained_name='lothritz/LuxemBERT', dropout=0.3):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.drop = nn.Dropout(p=dropout)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits for a batch of token ids.

        Args:
            input_ids: Token id tensor, typically of shape (batch, seq_len).
            attention_mask: Attention mask tensor matching ``input_ids``.

        Returns:
            Logit tensor whose last dimension has size ``n_classes``.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        # With return_dict=False the model returns a tuple whose second item
        # is the pooler output. Index it instead of unpacking exactly two
        # values so extra outputs (hidden states / attentions) don't break it.
        pooled_output = outputs[1]
        # Dropout only has an effect while the module is in training mode.
        output = self.drop(pooled_output)
        return self.out(output)
# Build the BERT sentiment model (one output per entry in class_names) and
# load the fine-tuned weights, mapping every tensor to CPU so the app also
# runs on machines without a GPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
model = SentimentClassifier(len(class_names))
model.load_state_dict(torch.load('./pytorch_model.bin', map_location=torch.device('cpu')))
# Switch to inference mode: disables dropout so the same text always gets
# the same prediction.
model.eval()
# Tokenizer files are loaded from the current working directory.
tokenizer = BertTokenizer.from_pretrained('./')

def encode(text, max_length=50):
    """Tokenize *text* into fixed-length PyTorch tensors for the model.

    Args:
        text: Input sentence.
        max_length: Sequence length to pad/truncate to (default 50).

    Returns:
        The tokenizer's encoding, containing ``input_ids`` and
        ``attention_mask`` tensors (no ``token_type_ids``).
    """
    return tokenizer.encode_plus(
        text,
        max_length=max_length,
        add_special_tokens=True,
        return_token_type_ids=False,
        # `pad_to_max_length` is deprecated. These are its modern
        # equivalents. Truncation is made explicit so long inputs are cut to
        # max_length instead of relying on the implicit (warned-about) default.
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

def classify(text):
    """Predict the sentiment label of *text*.

    Args:
        text: Input sentence from the UI.

    Returns:
        The entry of ``class_names`` whose logit is highest.
    """
    encoded_comment = encode(text)
    input_ids = encoded_comment['input_ids']
    attention_mask = encoded_comment['attention_mask']

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
    # A single sentence is encoded, so take the argmax of the first (only)
    # row and convert it to a plain int before indexing the label list.
    prediction = torch.argmax(logits, dim=1)[0].item()
    return class_names[prediction]

# Wire the classifier into a text-in / text-out Gradio web UI and start it.
demo = gr.Interface(
    fn=classify,
    inputs="text",
    outputs="text",
    title="Sentiment Analyser",
    description="Text classifier for Luxembourgish",
)

demo.launch()