Anushree1 commited on
Commit
7c2b64b
·
verified ·
1 Parent(s): a12bc7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -7
app.py CHANGED
@@ -2,21 +2,44 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
- # Load Pretrained Model & Tokenizer (XLM-Roberta for multilingual text classification)
6
  MODEL_NAME = "xlm-roberta-base"
7
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=5)
8
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
 
 
 
 
 
 
 
 
 
 
10
  # Classification Function
11
  def classify_text(text):
12
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
 
13
  with torch.no_grad():
14
  outputs = model(**inputs)
15
- label = torch.argmax(outputs.logits, dim=1).item()
16
- return f"Predicted Category: {label}"
 
 
 
 
 
 
 
 
 
17
 
18
  # Gradio UI
19
- demo = gr.Interface(fn=classify_text, inputs=gr.Textbox(lines=2, placeholder="Enter business document text..."),
20
- outputs="text", title="Multilingual Business Document Classifier")
 
 
 
 
21
 
22
- demo.launch()
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
+ # Load Pretrained Model & Tokenizer (Ensure this is a fine-tuned model)
6
  MODEL_NAME = "xlm-roberta-base"
7
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=5) # Adjust num_labels as per training
8
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
 
10
+ # Define Label Mapping (Modify based on your dataset)
11
+ LABEL_MAPPING = {
12
+ 0: "Contract",
13
+ 1: "Invoice",
14
+ 2: "Financial Report",
15
+ 3: "Legal Notice",
16
+ 4: "Marketing Material"
17
+ }
18
+
19
  # Classification Function
20
  def classify_text(text):
21
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
22
+
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
+
26
+ # Convert logits to probabilities
27
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
28
+
29
+ # Get predicted label index
30
+ label_idx = torch.argmax(probs, dim=1).item()
31
+
32
+ # Retrieve category name
33
+ category = LABEL_MAPPING.get(label_idx, "Unknown")
34
+
35
+ return f"Predicted Category: {category} (Confidence: {probs[0][label_idx]:.2f})"
36
 
37
  # Gradio UI
38
+ demo = gr.Interface(
39
+ fn=classify_text,
40
+ inputs=gr.Textbox(lines=4, placeholder="Enter business document text..."),
41
+ outputs="text",
42
+ title="Multilingual Business Document Classifier"
43
+ )
44
 
45
+ demo.launch()