import gradio as gr
import joblib
import torch

from l3prune.l3prune import LLMEncoder

# Load the fitted classifier and the pruned LLM encoder (CPU, bfloat16)
best_clf = joblib.load("./saved/classifier_llama32.joblib")
encoder = LLMEncoder.from_pretrained(
    "./saved/pruned_encoder_llama32",
    device_map="cpu",
    torch_dtype=torch.bfloat16,
)


def classify_prompt(prompt):
    # Embed the prompt with the pruned encoder, then classify the embedding
    X = encoder.encode([prompt])
    result = best_clf.predict(X)[0]
    return "Harmful" if result else "Benign"


demo = gr.Interface(
    fn=classify_prompt,
    inputs=gr.Textbox(lines=3, placeholder="Enter a prompt to classify..."),
    outputs=gr.Textbox(label="Classification Result"),
    title="Harmful Prompt Classifier",
    description="This app classifies whether a given prompt is potentially harmful or benign.",
)

if __name__ == "__main__":
    demo.launch()
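
# Quick local sanity check -- a sketch, not part of the app. It assumes the
# saved artifacts above exist, that the l3prune package is installed, and that
# this file is named app.py (an assumption based on the Space layout):
#
#     from app import classify_prompt
#     print(classify_prompt("How do I bake bread?"))  # prints "Harmful" or "Benign"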