Anushree1 committed on
Commit
46279fd
·
verified ·
1 Parent(s): e559725

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -21
app.py CHANGED
@@ -2,15 +2,21 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
- # Load Fine-Tuned Model & Tokenizer (Ensure path points to your fine-tuned model)
6
- MODEL_PATH = "path_to_fine_tuned_model" # Replace with the correct model path
7
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
8
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
9
 
10
- # Set model to evaluation mode (Disables dropout for stable predictions)
11
- model.eval()
12
 
13
- # Define Label Mapping (Modify based on your dataset)
 
 
 
 
 
 
 
 
14
  LABEL_MAPPING = {
15
  0: "Contract",
16
  1: "Invoice",
@@ -19,12 +25,8 @@ LABEL_MAPPING = {
19
  4: "Marketing Material"
20
  }
21
 
22
- # Optimized Classification Function
23
  def classify_text(text):
24
- if not text.strip():
25
- return "Please enter a valid business document text."
26
-
27
- # Tokenize Input
28
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
29
 
30
  with torch.no_grad():
@@ -33,25 +35,23 @@ def classify_text(text):
33
  # Convert logits to probabilities
34
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
35
 
36
- # Get predicted label index
37
  label_idx = torch.argmax(probs, dim=1).item()
 
38
 
39
  # Retrieve category name
40
  category = LABEL_MAPPING.get(label_idx, "Unknown")
41
 
42
- # Debugging Info (Uncomment for testing)
43
- print(f"Logits: {outputs.logits}")
44
- print(f"Probabilities: {probs}")
45
-
46
- return f"Predicted Category: {category} (Confidence: {probs[0][label_idx]:.2f})"
47
 
48
  # Gradio UI
49
  demo = gr.Interface(
50
  fn=classify_text,
51
  inputs=gr.Textbox(lines=4, placeholder="Enter business document text..."),
52
  outputs="text",
53
- title="Multilingual Business Document Classifier",
54
- description="Classifies business documents into predefined categories using a multilingual model."
55
  )
56
 
57
- demo.launch()
 
 
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
+ # Define the model path (Update this with your fine-tuned model's path or Hugging Face repo)
6
+ MODEL_PATH = "your-huggingface-username/your-fine-tuned-model"
 
 
7
 
8
+ # Authenticate if using a private model (Uncomment and set your token)
9
+ # TOKEN = "your_hf_access_token"
10
 
11
+ # Load Model & Tokenizer
12
+ try:
13
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) # , use_auth_token=TOKEN if needed
14
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH) # , use_auth_token=TOKEN if needed
15
+ except Exception as e:
16
+ print(f"Error loading model: {e}")
17
+ exit()
18
+
19
+ # Label Mapping
20
  LABEL_MAPPING = {
21
  0: "Contract",
22
  1: "Invoice",
 
25
  4: "Marketing Material"
26
  }
27
 
28
+ # Classification Function
29
  def classify_text(text):
 
 
 
 
30
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
31
 
32
  with torch.no_grad():
 
35
  # Convert logits to probabilities
36
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
37
 
38
+ # Get top predicted label
39
  label_idx = torch.argmax(probs, dim=1).item()
40
+ confidence = probs[0][label_idx].item()
41
 
42
  # Retrieve category name
43
  category = LABEL_MAPPING.get(label_idx, "Unknown")
44
 
45
+ return f"Predicted Category: {category} (Confidence: {confidence:.2f})"
 
 
 
 
46
 
47
  # Gradio UI
48
  demo = gr.Interface(
49
  fn=classify_text,
50
  inputs=gr.Textbox(lines=4, placeholder="Enter business document text..."),
51
  outputs="text",
52
+ title="Multilingual Business Document Classifier"
 
53
  )
54
 
55
+ # Run the app
56
+ if __name__ == "__main__":
57
+ demo.launch()