import gradio as gr
from transformers import pipeline

# Create a zero-shot classification pipeline.
# The model is named explicitly rather than relying on transformers' per-task
# default, which emits a warning and is not guaranteed to stay the same across
# library versions. NOTE(review): facebook/bart-large-mnli is the default this
# task has historically resolved to; confirm it matches the model you expect.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


# Categories every request is scored against; user-supplied labels are appended.
DEFAULT_LABELS = ("Education", "Business", "Sports", "Manufacturing")


def classify_text(text, additional_labels):
    """Classify *text* against the default labels plus any user-supplied ones.

    Args:
        text: The input text to classify.
        additional_labels: Optional comma-separated string of extra candidate
            labels (e.g. ``"Health, Travel"``). Whitespace around each entry is
            stripped, empty entries are ignored, and entries that repeat an
            earlier label (compared case-insensitively) are skipped. ``None``
            or ``""`` means no extra labels.

    Returns:
        One ``"Label: <label>, Score: <score>"`` line per candidate label, in
        the order the pipeline returns them, joined with newlines. Scores are
        rounded to 4 decimal places.
    """
    labels = list(DEFAULT_LABELS)
    seen = {label.casefold() for label in labels}
    for raw in (additional_labels or "").split(","):
        label = raw.strip()
        # Blank entries (from "a,,b" or a trailing comma) and duplicates would
        # otherwise be scored as separate candidates and dilute the scores.
        if label and label.casefold() not in seen:
            seen.add(label.casefold())
            labels.append(label)

    result = classifier(text, candidate_labels=labels)

    return "\n".join(
        f"Label: {label}, Score: {round(score, 4)}"
        for label, score in zip(result["labels"], result["scores"])
    )


# Create a Gradio interface. Labeled textboxes tell the user which box holds the
# text and which holds the optional comma-separated extra labels.
interface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(label="Text", lines=4),
        gr.Textbox(
            label="Additional labels",
            placeholder="Optional, comma-separated (e.g. Health, Travel)",
        ),
    ],
    outputs="text",
    title="Text Classification",
    description="Enter a text to classify into categories: Education, Business, Sports, Manufacturing. Optionally, add more categories separated by commas."
)

# Launch only when run as a script, so importing this module (e.g. from tests
# or another app) does not start a web server as a side effect.
if __name__ == "__main__":
    interface.launch()