Pisethan commited on
Commit
7f5bdbb
·
verified ·
1 Parent(s): 8882087

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -14
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
3
 
4
- # Load your model from Hugging Face Hub
5
  MODEL_NAME = "Pisethan/sangapac-math"
6
  reverse_label_mapping = {
7
  0: "arithmetic",
@@ -11,22 +11,31 @@ reverse_label_mapping = {
11
  4: "geometry",
12
  }
13
 
14
- # Load tokenizer and model
15
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
16
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
17
- classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
 
 
 
 
18
 
19
  def predict(input_text):
20
- # Predict using the model
21
- result = classifier(input_text)
22
- label_id = int(result[0]["label"].split("_")[-1]) # Extract label ID
23
- category = reverse_label_mapping[label_id] # Map label to category
24
 
25
- # Return prediction result
26
- return {
27
- "Category": category,
28
- "Confidence": result[0]["score"],
29
- }
 
 
 
 
 
 
 
30
 
31
  # Gradio interface
32
  interface = gr.Interface(
 
1
  import gradio as gr
2
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
3
 
4
+ # Model details
5
  MODEL_NAME = "Pisethan/sangapac-math"
6
  reverse_label_mapping = {
7
  0: "arithmetic",
 
11
  4: "geometry",
12
  }
13
 
14
+ # Load model and tokenizer
15
+ try:
16
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
17
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
18
+ classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
19
+ except Exception as e:
20
+ classifier = None
21
+ print(f"Error loading model or tokenizer: {e}")
22
 
23
  def predict(input_text):
24
+ if classifier is None:
25
+ return {"Error": "Model not loaded properly."}
 
 
26
 
27
+ try:
28
+ # Predict the category
29
+ result = classifier(input_text)
30
+ label_id = int(result[0]["label"].split("_")[-1]) # Extract label ID
31
+ category = reverse_label_mapping[label_id] # Map label to category
32
+
33
+ return {
34
+ "Category": category,
35
+ "Confidence": result[0]["score"],
36
+ }
37
+ except Exception as e:
38
+ return {"Error": str(e)}
39
 
40
  # Gradio interface
41
  interface = gr.Interface(