# app.py — Hugging Face Space file by Anushree1
# Last commit: 46279fd "Update app.py" (verified); file size 1.71 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Define the model path: a local directory or a Hugging Face Hub repo id of
# your fine-tuned sequence-classification model (update before deploying).
MODEL_PATH = "your-huggingface-username/your-fine-tuned-model"

# Authenticate if using a private model: uncomment the line below, set your
# token, and pass `token=TOKEN` to both `from_pretrained` calls. (Older
# transformers releases used the now-deprecated `use_auth_token=` keyword.)
# TOKEN = "your_hf_access_token"

# Load model & tokenizer once at startup so every request reuses them.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)  # , token=TOKEN if needed
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)  # , token=TOKEN if needed
except Exception as e:
    # The app cannot serve anything without a model, so abort startup.
    # Raising SystemExit with a message prints it to stderr and exits with a
    # non-zero status. The previous bare `exit()` reported success (status 0)
    # and relied on the `site`-injected helper, which is not always present.
    raise SystemExit(f"Error loading model: {e}") from e

# Make inference mode explicit (disables dropout); harmless if already set.
model.eval()
# Label mapping: class index -> human-readable document category.
# Position i in the tuple below becomes the name for class index i. The order
# is assumed to match the label order used when fine-tuning the model — confirm.
LABEL_MAPPING = dict(enumerate((
    "Contract",
    "Invoice",
    "Financial Report",
    "Legal Notice",
    "Marketing Material",
)))
# Classification Function
def classify_text(text):
    """Classify a business document and return a human-readable prediction.

    Args:
        text: Raw document text from the Gradio textbox. Inputs longer than
            512 tokens are truncated by the tokenizer.

    Returns:
        ``"Predicted Category: <name> (Confidence: <p>)"``, where ``<p>`` is
        the softmax probability of the top class formatted to two decimals.
        If ``text`` is empty or whitespace-only, returns a prompt asking for
        input instead.
    """
    # Guard: an empty box would otherwise be classified from the special
    # tokens alone, producing a confident-looking but meaningless answer.
    if not text or not text.strip():
        return "Please enter some document text to classify."

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # Keep the input tensors on the model's device so this still works if the
    # model is moved to a GPU.
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Convert logits to probabilities over the classes (batch of one).
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # Top predicted class and its probability.
    label_idx = torch.argmax(probs, dim=-1).item()
    confidence = probs[0, label_idx].item()

    # Prefer the hand-written mapping. Fall back to the model config's own
    # label names, and only then to "Unknown", so that extra classes are not
    # silently reported as "Unknown".
    category = LABEL_MAPPING.get(label_idx)
    if category is None:
        id2label = getattr(model.config, "id2label", None) or {}
        category = id2label.get(label_idx, "Unknown")
    return f"Predicted Category: {category} (Confidence: {confidence:.2f})"
# Gradio UI: a single multi-line textbox wired to the classifier, plain-text output.
_document_box = gr.Textbox(lines=4, placeholder="Enter business document text...")

demo = gr.Interface(
    fn=classify_text,
    inputs=_document_box,
    outputs="text",
    title="Multilingual Business Document Classifier",
)

# Start the web server only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()