File size: 1,712 Bytes
d632c9a
 
 
 
46279fd
 
e559725
46279fd
 
d632c9a
46279fd
 
 
 
 
 
 
 
 
7c2b64b
 
 
 
 
 
 
 
46279fd
d632c9a
 
e559725
d632c9a
 
e559725
7c2b64b
 
 
46279fd
7c2b64b
46279fd
7c2b64b
 
 
e559725
46279fd
d632c9a
 
7c2b64b
 
 
 
46279fd
7c2b64b
d632c9a
46279fd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Define the model path: a local directory or a Hugging Face Hub repo id.
# Update this with your fine-tuned model's path or Hugging Face repo.
MODEL_PATH = "your-huggingface-username/your-fine-tuned-model"

# Authenticate if using a private model: uncomment, set your token, and pass it
# to both from_pretrained calls below. Recent transformers releases accept
# `token=TOKEN`; older releases used the now-deprecated `use_auth_token=TOKEN`.
# TOKEN = "your_hf_access_token"

# Load Model & Tokenizer once at import time; classify_text reuses these
# module-level objects for every request.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)  # , token=TOKEN if needed
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)  # , token=TOKEN if needed
except Exception as e:
    # The app cannot serve anything without a model, so abort here.
    # SystemExit(str) prints the message to stderr and exits with status 1.
    # The interactive `exit()` helper may be absent (python -S, embedded
    # interpreters), and without an argument it would exit with status 0.
    raise SystemExit(f"Error loading model: {e}") from e

# Label Mapping: classifier output index -> human-readable document category.
# NOTE(review): the tuple order presumably must match the label order used
# when the model was fine-tuned — confirm against model.config.id2label.
LABEL_MAPPING = dict(
    enumerate(
        (
            "Contract",
            "Invoice",
            "Financial Report",
            "Legal Notice",
            "Marketing Material",
        )
    )
)

# Classification Function
def classify_text(text):
    """Classify a business document and return a human-readable verdict.

    Args:
        text: Raw document text from the Gradio textbox. Input longer than
            512 tokens is truncated by the tokenizer.

    Returns:
        "Predicted Category: <name> (Confidence: <p>)", where <p> is the
        softmax probability of the top class formatted to two decimals and
        indices missing from LABEL_MAPPING are reported as "Unknown".
        Empty or whitespace-only input returns a prompt instead of running
        the model.
    """
    # Without this guard the model would still run on an essentially empty
    # input and report a confident-looking but meaningless category.
    if not text or not text.strip():
        return "Please enter some text to classify."

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # Put the inputs on the same device as the model, so this keeps working
    # if the model is moved off the CPU. It is a no-op on CPU.
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Convert logits to probabilities over the label axis. A single string
    # was tokenized, so take the one row of the batch.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

    # Get top predicted label and its probability
    label_idx = torch.argmax(probs).item()
    confidence = probs[label_idx].item()

    # Retrieve category name
    category = LABEL_MAPPING.get(label_idx, "Unknown")

    return f"Predicted Category: {category} (Confidence: {confidence:.2f})"

# Gradio UI: one multi-line text box in, one plain-text verdict out,
# with classify_text wired up as the handler.
demo = gr.Interface(
    title="Multilingual Business Document Classifier",
    fn=classify_text,
    inputs=gr.Textbox(
        placeholder="Enter business document text...",
        lines=4,
    ),
    outputs="text",
)

# Launch the web app only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()