offry commited on
Commit
ab29ae2
·
1 Parent(s): f681eb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -371,9 +371,9 @@ def predict(inp):
371
  with torch.no_grad():
372
  classification, reconstruction = model(image_tensor.unsqueeze(0))
373
  periods_classification = periods_model(image_tensor.unsqueeze(0))
374
- # recon_tensor = reconstruction[0].repeat(3, 1, 1)
375
- # recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
376
- # plot_recon = recon_tensor.permute(1, 2, 0).detach().numpy()
377
  w, h = inp.size
378
  # plot_recon = resize(plot_recon, (h, w))
379
  m = nn.Sigmoid()
@@ -405,7 +405,7 @@ def predict(inp):
405
  dists[image_name] = dist
406
  res = dict(sorted(dists.items(), key=itemgetter(1)))
407
  fig, names = create_retrieval_figure(res)
408
- return confidences, periods_confidences, fig, names
409
 
410
 
411
  gr.Interface(fn=predict,
@@ -415,4 +415,4 @@ gr.Interface(fn=predict,
415
  "https://arxiv.org/abs/2211.09480.",
416
  examples=['anth.jpg', 'beetle_snakes.jpg', 'bird.jpg', 'cross.jpg', 'ibex.jpg',
417
  'lion.jpg', 'lion2.jpg', 'sa.jpg'],
418
- outputs=[gr.Label(num_top_classes=1), gr.Label(num_top_classes=1), 'plot', 'text'], ).launch(share=True, enable_queue=True)
 
371
  with torch.no_grad():
372
  classification, reconstruction = model(image_tensor.unsqueeze(0))
373
  periods_classification = periods_model(image_tensor.unsqueeze(0))
374
+ recon_tensor = reconstruction[0].repeat(3, 1, 1)
375
+ recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
376
+ plot_recon = recon_tensor.permute(1, 2, 0).detach().numpy()
377
  w, h = inp.size
378
  # plot_recon = resize(plot_recon, (h, w))
379
  m = nn.Sigmoid()
 
405
  dists[image_name] = dist
406
  res = dict(sorted(dists.items(), key=itemgetter(1)))
407
  fig, names = create_retrieval_figure(res)
408
+ return confidences, periods_confidences, plot_recon, fig, names
409
 
410
 
411
  gr.Interface(fn=predict,
 
415
  "https://arxiv.org/abs/2211.09480.",
416
  examples=['anth.jpg', 'beetle_snakes.jpg', 'bird.jpg', 'cross.jpg', 'ibex.jpg',
417
  'lion.jpg', 'lion2.jpg', 'sa.jpg'],
418
+ outputs=[gr.Label(num_top_classes=1), gr.Label(num_top_classes=1), "image", 'plot', 'text'], ).launch(share=True, enable_queue=True)