ombhojane commited on
Commit
83e1827
·
verified ·
1 Parent(s): 343dde1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -5,18 +5,25 @@ import gradio as gr
5
  # Load model and processor from Hugging Face
6
  model_name = "ombhojane/healthyPlantsModel"
7
 
8
- # Load model (ensure it's a classification model)
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
  processor = AutoProcessor.from_pretrained(model_name)
11
 
 
 
 
12
  # Define a function to run inference
13
  def classify_image(image):
14
  inputs = processor(images=image, return_tensors="pt")
15
  with torch.no_grad():
16
  outputs = model(**inputs)
17
  logits = outputs.logits
18
- predicted_class = logits.argmax(-1).item()
19
- return f"Predicted Class: {predicted_class}"
 
 
 
 
20
 
21
  # Create a Gradio interface
22
  demo = gr.Interface(fn=classify_image, inputs="image", outputs="text")
 
5
  # Load model and processor from Hugging Face
6
  model_name = "ombhojane/healthyPlantsModel"
7
 
8
+ # Load model and processor
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
  processor = AutoProcessor.from_pretrained(model_name)
11
 
12
+ # Get class labels from model config
13
+ id2label = model.config.id2label # Mapping of indices to class names
14
+
15
  # Define a function to run inference
16
  def classify_image(image):
17
  inputs = processor(images=image, return_tensors="pt")
18
  with torch.no_grad():
19
  outputs = model(**inputs)
20
  logits = outputs.logits
21
+ predicted_class_idx = logits.argmax(-1).item()
22
+
23
+ # Get human-readable class label
24
+ predicted_class_name = id2label.get(predicted_class_idx, "Unknown")
25
+
26
+ return f"Predicted Class: {predicted_class_name}"
27
 
28
  # Create a Gradio interface
29
  demo = gr.Interface(fn=classify_image, inputs="image", outputs="text")