0-ma commited on
Commit
c7d5210
·
verified ·
1 Parent(s): 192d452

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -3
app.py CHANGED
@@ -1,11 +1,37 @@
1
  import gradio as gr
2
- import skimage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  labels = []
5
  def predict(img):
6
  img = PILImage.create(img)
7
- pred,pred_idx,probs = detector.predict(img)
8
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
 
 
 
 
 
 
9
 
10
  title = "Geometric Shape Classifier"
11
  description = "A geometric shape setector."
 
1
  import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
5
+ import requests
6
+ labels = [
7
+ "None",
8
+ "Circle",
9
+ "Triangle",
10
+ "Square",
11
+ "Pentagon",
12
+ "Hexagon"
13
+ ]
14
+ #images = [Image.open(requests.get("https://raw.githubusercontent.com/0-ma/geometric-shape-detector/main/input/exemple_circle.jpg", stream=True).raw),
15
+ # Image.open(requests.get("https://raw.githubusercontent.com/0-ma/geometric-shape-detector/main/input/exemple_pentagone.jpg", stream=True).raw)]
16
+ feature_extractor = AutoImageProcessor.from_pretrained('0-ma/swin-geometric-shapes-tiny')
17
+ model = AutoModelForImageClassification.from_pretrained('0-ma/swin-geometric-shapes-tiny')
18
+
19
+
20
+ print(predicted_labels)
21
+
22
+
23
 
24
  labels = []
25
  def predict(img):
26
  img = PILImage.create(img)
27
+
28
+ inputs = feature_extractor(images=images, return_tensors="pt")
29
+ logits = model(**inputs)['logits'].cpu().detach().numpy()
30
+ predictions = np.argmax(logits, axis=1)
31
+ predicted_labels = [labels[prediction] for prediction in predictions]
32
+
33
+
34
+ return {"predicted_labels" : predicted_labels , "predictions": predictions}
35
 
36
  title = "Geometric Shape Classifier"
37
  description = "A geometric shape setector."