# Spaces:
# Sleeping
# Sleeping
import gradio as gr | |
import joblib | |
# Label for each model output column: position i names column i of a prediction.
class_names = [
    "Family Issues",
    "Relationship Conflicts",
    "Work Dynamics",
    "Financial and Legal Disagreements",
    "Personal Boundaries",
    "Cultural and Identity-Based Issues",
    "Other",
]
# Define the custom pipeline
class CustomSVMTextClassificationPipeline:
    """Multi-label text classifier built from a saved vectorizer and model.

    Both artifacts are loaded with ``joblib.load``. Calling the instance
    vectorizes the input text(s), runs ``model.predict`` and, for every
    prediction row, collects the label of each column whose value equals 1.
    """

    def __init__(self, model_path, vectorizer_path, labels=None):
        """Load the model and vectorizer from disk.

        Args:
            model_path: Path to the serialized classifier (joblib file).
            vectorizer_path: Path to the serialized vectorizer (joblib file).
            labels: Optional sequence of label names, one per model output
                column. When ``None``, the module-level ``class_names`` list
                is used (looked up at call time).
        """
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code -- only point these paths at trusted artifacts.
        self.model = joblib.load(model_path)
        self.vectorizer = joblib.load(vectorizer_path)
        self.labels = labels

    def __call__(self, texts):
        """Classify one text or an iterable of texts.

        Args:
            texts: A single string, or an iterable of strings.

        Returns:
            For a single string, the list of predicted label names. For an
            iterable, a list containing one such list per input text; an
            empty iterable yields ``[]``.
        """
        single = isinstance(texts, str)
        # Normalize to a list so the vectorizer always receives a batch.
        batch = [texts] if single else list(texts)
        if not batch:
            # Nothing to classify; avoid calling the model on an empty batch.
            return []
        features = self.vectorizer.transform(batch)
        predictions = self.model.predict(features)
        labels = class_names if self.labels is None else self.labels
        # Column i of a prediction row corresponds to labels[i].
        results = [
            [labels[i] for i, value in enumerate(row) if value == 1]
            for row in predictions
        ]
        # Unwrap only for plain-string input, so list input always yields a
        # list of results (even when the list has a single element).
        return results[0] if single else results
# Locations of the serialized artifacts; edit these if the files live elsewhere.
model_path = "svm_multi_output_model.pkl"
vectorizer_path = "tfidf_vectorizer.pkl"

# Shared classifier instance, built once when the module is imported.
classifier = CustomSVMTextClassificationPipeline(
    model_path=model_path,
    vectorizer_path=vectorizer_path,
)
def classify_text(input_text):
    """Return the labels the shared module-level classifier assigns to *input_text*.

    This is the click handler for the Gradio button; whatever it returns is
    rendered by the JSON output component.
    """
    return classifier(input_text)
# Build the UI: a heading, a textbox, a JSON result panel, and a button that
# sends the textbox contents through classify_text into the JSON panel.
# Components render in the order they are created inside the Blocks context.
with gr.Blocks() as app:
    gr.Markdown("# Text Classification App")
    gr.Markdown("Enter text to classify:")
    input_text = gr.Textbox(label="Input Text")
    output = gr.JSON(label="Classification Results")
    submit_button = gr.Button("Classify")
    submit_button.click(
        fn=classify_text,
        inputs=[input_text],
        outputs=[output],
    )
# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    app.launch()