Anushree1 commited on
Commit
e559725
·
verified ·
1 Parent(s): 7c2b64b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -9
app.py CHANGED
@@ -2,10 +2,13 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
- # Load Pretrained Model & Tokenizer (Ensure this is a fine-tuned model)
6
- MODEL_NAME = "xlm-roberta-base"
7
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=5) # Adjust num_labels as per training
8
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
 
 
9
 
10
  # Define Label Mapping (Modify based on your dataset)
11
  LABEL_MAPPING = {
@@ -16,13 +19,17 @@ LABEL_MAPPING = {
16
  4: "Marketing Material"
17
  }
18
 
19
- # Classification Function
20
  def classify_text(text):
 
 
 
 
21
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
22
-
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
-
26
  # Convert logits to probabilities
27
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
28
 
@@ -31,7 +38,11 @@ def classify_text(text):
31
 
32
  # Retrieve category name
33
  category = LABEL_MAPPING.get(label_idx, "Unknown")
34
-
 
 
 
 
35
  return f"Predicted Category: {category} (Confidence: {probs[0][label_idx]:.2f})"
36
 
37
  # Gradio UI
@@ -39,7 +50,8 @@ demo = gr.Interface(
39
  fn=classify_text,
40
  inputs=gr.Textbox(lines=4, placeholder="Enter business document text..."),
41
  outputs="text",
42
- title="Multilingual Business Document Classifier"
 
43
  )
44
 
45
  demo.launch()
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
+ # Load Fine-Tuned Model & Tokenizer (Ensure path points to your fine-tuned model)
6
+ MODEL_PATH = "path_to_fine_tuned_model" # Replace with the correct model path
7
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
8
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
9
+
10
+ # Set model to evaluation mode (Disables dropout for stable predictions)
11
+ model.eval()
12
 
13
  # Define Label Mapping (Modify based on your dataset)
14
  LABEL_MAPPING = {
 
19
  4: "Marketing Material"
20
  }
21
 
22
+ # Optimized Classification Function
23
  def classify_text(text):
24
+ if not text.strip():
25
+ return "Please enter a valid business document text."
26
+
27
+ # Tokenize Input
28
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
29
+
30
  with torch.no_grad():
31
  outputs = model(**inputs)
32
+
33
  # Convert logits to probabilities
34
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
35
 
 
38
 
39
  # Retrieve category name
40
  category = LABEL_MAPPING.get(label_idx, "Unknown")
41
+
42
+ # Debugging Info (Uncomment for testing)
43
+ print(f"Logits: {outputs.logits}")
44
+ print(f"Probabilities: {probs}")
45
+
46
  return f"Predicted Category: {category} (Confidence: {probs[0][label_idx]:.2f})"
47
 
48
  # Gradio UI
 
50
  fn=classify_text,
51
  inputs=gr.Textbox(lines=4, placeholder="Enter business document text..."),
52
  outputs="text",
53
+ title="Multilingual Business Document Classifier",
54
+ description="Classifies business documents into predefined categories using a multilingual model."
55
  )
56
 
57
  demo.launch()