Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -37,6 +37,8 @@ def inference(input_image):
|
|
37 |
if torch.cuda.is_available():
|
38 |
input_batch = input_batch.to('cuda')
|
39 |
model.to('cuda')
|
|
|
|
|
40 |
|
41 |
with torch.no_grad():
|
42 |
output = model(input_batch)
|
@@ -46,6 +48,16 @@ def inference(input_image):
|
|
46 |
# Read the categories
|
47 |
with open("artist_classes.txt", "r") as f:
|
48 |
categories = [s.strip() for s in f.readlines()]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
# Show top categories per image
|
50 |
top5_prob, top5_catid = torch.topk(probabilities, 6)
|
51 |
result = {}
|
|
|
37 |
if torch.cuda.is_available():
|
38 |
input_batch = input_batch.to('cuda')
|
39 |
model.to('cuda')
|
40 |
+
else:
|
41 |
+
model.to('cpu')
|
42 |
|
43 |
with torch.no_grad():
|
44 |
output = model(input_batch)
|
|
|
48 |
# Read the categories
|
49 |
with open("artist_classes.txt", "r") as f:
|
50 |
categories = [s.strip() for s in f.readlines()]
|
51 |
+
|
52 |
+
categories = {
|
53 |
+
0:"vanGogh",
|
54 |
+
1:"Monet",
|
55 |
+
2:"Leonardo da Vinci",
|
56 |
+
3:"Rembrandt",
|
57 |
+
4:"Pablo Picasso",
|
58 |
+
5:"Salvador Dali"
|
59 |
+
}
|
60 |
+
|
61 |
# Show top categories per image
|
62 |
top5_prob, top5_catid = torch.topk(probabilities, 6)
|
63 |
result = {}
|