Update app.py
Browse files
app.py
CHANGED
@@ -371,9 +371,9 @@ def predict(inp):
|
|
371 |
with torch.no_grad():
|
372 |
classification, reconstruction = model(image_tensor.unsqueeze(0))
|
373 |
periods_classification = periods_model(image_tensor.unsqueeze(0))
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
w, h = inp.size
|
378 |
# plot_recon = resize(plot_recon, (h, w))
|
379 |
m = nn.Sigmoid()
|
@@ -405,7 +405,7 @@ def predict(inp):
|
|
405 |
dists[image_name] = dist
|
406 |
res = dict(sorted(dists.items(), key=itemgetter(1)))
|
407 |
fig, names = create_retrieval_figure(res)
|
408 |
-
return confidences, periods_confidences, fig, names
|
409 |
|
410 |
|
411 |
gr.Interface(fn=predict,
|
@@ -415,4 +415,4 @@ gr.Interface(fn=predict,
|
|
415 |
"https://arxiv.org/abs/2211.09480.",
|
416 |
examples=['anth.jpg', 'beetle_snakes.jpg', 'bird.jpg', 'cross.jpg', 'ibex.jpg',
|
417 |
'lion.jpg', 'lion2.jpg', 'sa.jpg'],
|
418 |
-
outputs=[gr.Label(num_top_classes=1), gr.Label(num_top_classes=1), 'plot', 'text'], ).launch(share=True, enable_queue=True)
|
|
|
371 |
with torch.no_grad():
|
372 |
classification, reconstruction = model(image_tensor.unsqueeze(0))
|
373 |
periods_classification = periods_model(image_tensor.unsqueeze(0))
|
374 |
+
recon_tensor = reconstruction[0].repeat(3, 1, 1)
|
375 |
+
recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
|
376 |
+
plot_recon = recon_tensor.permute(1, 2, 0).detach().numpy()
|
377 |
w, h = inp.size
|
378 |
# plot_recon = resize(plot_recon, (h, w))
|
379 |
m = nn.Sigmoid()
|
|
|
405 |
dists[image_name] = dist
|
406 |
res = dict(sorted(dists.items(), key=itemgetter(1)))
|
407 |
fig, names = create_retrieval_figure(res)
|
408 |
+
return confidences, periods_confidences, plot_recon, fig, names
|
409 |
|
410 |
|
411 |
gr.Interface(fn=predict,
|
|
|
415 |
"https://arxiv.org/abs/2211.09480.",
|
416 |
examples=['anth.jpg', 'beetle_snakes.jpg', 'bird.jpg', 'cross.jpg', 'ibex.jpg',
|
417 |
'lion.jpg', 'lion2.jpg', 'sa.jpg'],
|
418 |
+
outputs=[gr.Label(num_top_classes=1), gr.Label(num_top_classes=1), "image", 'plot', 'text'], ).launch(share=True, enable_queue=True)
|