import gradio as gr
import torch
from moderation import *


# Build the moderation model once at import time and restore its trained
# weights from the checkpoint file in the working directory.
# NOTE(review): torch.load unpickles the file, so 'moderation_model.pth'
# must come from a trusted source.
moderation = ModerationModel()
moderation.load_state_dict(
    torch.load(
        'moderation_model.pth',
        # Map stored tensors onto the CPU so the checkpoint loads on
        # machines without a GPU; drop this argument when running on a GPU.
        map_location=torch.device('cpu'),
    )
)
# Put the model into evaluation mode before serving predictions.
moderation.eval()

def predict_moderation(text):
    """Run the moderation model on ``text`` and shape the result for the UI.

    The text is turned into embeddings with ``getEmb`` and handed, together
    with the module-level ``moderation`` model, to ``predict``.

    Args:
        text: The user-supplied text to check.

    Returns:
        A ``(category_scores, detected)`` tuple, where ``category_scores`` is
        the value stored under ``'category_scores'`` in the prediction (an
        empty dict when the key is missing) and ``detected`` is ``str()`` of
        the value stored under ``'detected'`` (``"False"`` when missing).
    """
    result = predict(moderation, getEmb(text))
    scores = result.get('category_scores', {})
    flagged = result.get('detected', False)
    # The flag is stringified so it can be shown in a label component.
    return scores, str(flagged)


# Gradio front end: one free-text input, and two label outputs that receive
# the (category_scores, detected) pair returned by predict_moderation.
iface = gr.Interface(
    fn=predict_moderation,
    inputs="text",
    outputs=[
        gr.Label(label="Category Scores"),
        gr.Label(label="Detected"),
    ],
    title="Moderation Model",
    description="Enter text to check for moderation flags.",
)


# Start the web app as soon as the script is executed.
iface.launch()