import gradio as gr
from transformers import BertTokenizer, BertForSequenceClassification
import torch

# Function to load model and tokenizer
def load_model(model_name="Minej/bert-base-personality"):
    """Load the BERT tokenizer and sequence-classification model.

    Args:
        model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to the checkpoint this app
            was written against.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # A single identifier feeds both calls so the tokenizer and the model
    # can never be loaded from two different checkpoints by accident.
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model

# Load the tokenizer and model once, when this module is executed;
# personality_detection reads these module-level names on every call.
tokenizer, model = load_model()

# Function to predict personality traits
def personality_detection(text):
    """Score *text* against the five personality trait labels.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: The sentence to analyse (assumed to be a single string, as
            supplied by the textbox input).

    Returns:
        A dict mapping each trait name to a built-in ``float``: the softmax
        of the model's logits for *text*.
    """
    label_names = (
        'Extroversion',
        'Neuroticism',
        'Agreeableness',
        'Conscientiousness',
        'Openness',
    )

    inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        logits = model(**inputs).logits
        # NOTE(review): softmax forces the five scores to sum to 1; confirm
        # that matches how this checkpoint's outputs are meant to be read.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)

    # One input string gives a batch of one: take that row explicitly and
    # convert to built-in floats so the result is plainly JSON-serialisable
    # (numpy scalars are not handled by every serializer).
    scores = probabilities[0].tolist()
    return dict(zip(label_names, scores))

# Build the UI pieces first: a two-line textbox for the sentence and a
# Label component that displays the mapping returned by the callback.
text_input = gr.Textbox(lines=2, placeholder="Enter a sentence here...")
label_output = gr.Label()

# Wire the components to personality_detection.
interface = gr.Interface(
    fn=personality_detection,
    inputs=text_input,
    outputs=label_output,
    title="Personality Analyzer",
    description="Enter a sentence and get a prediction of personality traits.",
)

# Start the app on port 7861; choose another port if this one is taken.
# NOTE(review): this runs whenever the module is executed, including on
# import -- consider an `if __name__ == "__main__":` guard if that matters.
interface.launch(server_port=7861)