import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Define the model path (update this with your fine-tuned model's path or Hugging Face repo)
MODEL_PATH = "your-huggingface-username/your-fine-tuned-model"

# Authenticate if using a private model (uncomment and set your token)
# TOKEN = "your_hf_access_token"
# Load model & tokenizer
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)  # add token=TOKEN if the repo is private
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)  # add token=TOKEN if the repo is private
    model.eval()  # inference mode
except Exception as e:
    print(f"Error loading model: {e}")
    raise  # re-raise so the Space surfaces the load failure instead of silently exiting
# Label mapping (class index -> document category)
LABEL_MAPPING = {
    0: "Contract",
    1: "Invoice",
    2: "Financial Report",
    3: "Legal Notice",
    4: "Marketing Material",
}
# Classification function
def classify_text(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    # Convert logits to probabilities
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    # Get the top predicted label and its confidence
    label_idx = torch.argmax(probs, dim=1).item()
    confidence = probs[0][label_idx].item()
    # Retrieve the category name
    category = LABEL_MAPPING.get(label_idx, "Unknown")
    return f"Predicted Category: {category} (Confidence: {confidence:.2f})"
# Gradio UI
demo = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(lines=4, placeholder="Enter business document text..."),
    outputs="text",
    title="Multilingual Business Document Classifier",
)
# Run the app
if __name__ == "__main__":
    demo.launch()
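
# A minimal local smoke test (a sketch only; it assumes the model above loaded
# successfully and that the label indices match your fine-tuned checkpoint).
# Run it in a Python shell after importing this module, or temporarily place it
# above demo.launch():
#
#     sample = "This agreement is made and entered into by the undersigned parties..."
#     print(classify_text(sample))  # prints the predicted category and confidence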