lxy1122 commited on
Commit
9c8b8a2
·
1 Parent(s): eb4f571

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -0
app.py CHANGED
@@ -37,6 +37,8 @@ def inference(input_image):
37
  if torch.cuda.is_available():
38
  input_batch = input_batch.to('cuda')
39
  model.to('cuda')
 
 
40
 
41
  with torch.no_grad():
42
  output = model(input_batch)
@@ -46,6 +48,16 @@ def inference(input_image):
46
  # Read the categories
47
  with open("artist_classes.txt", "r") as f:
48
  categories = [s.strip() for s in f.readlines()]
 
 
 
 
 
 
 
 
 
 
49
  # Show top categories per image
50
  top5_prob, top5_catid = torch.topk(probabilities, 6)
51
  result = {}
 
37
  if torch.cuda.is_available():
38
  input_batch = input_batch.to('cuda')
39
  model.to('cuda')
40
+ else:
41
+ model.to('cpu')
42
 
43
  with torch.no_grad():
44
  output = model(input_batch)
 
48
  # Read the categories
49
  with open("artist_classes.txt", "r") as f:
50
  categories = [s.strip() for s in f.readlines()]
51
+
52
+ categories = {
53
+ 0:"vanGogh",
54
+ 1:"Monet",
55
+ 2:"Leonardo da Vinci",
56
+ 3:"Rembrandt",
57
+ 4:"Pablo Picasso",
58
+ 5:"Salvador Dali"
59
+ }
60
+
61
  # Show top categories per image
62
  top5_prob, top5_catid = torch.topk(probabilities, 6)
63
  result = {}