File size: 1,387 Bytes
5bb056f 3972689 ea5d667 5bb056f 2bfbc42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import os
import gradio as gr
from cluster_visualize import PredictionArgs, generate_visualization
# Directory scanned (recursively) for the model checkpoints offered in the UI.
checkpoint_dir = 'checkpoints'
CHECKPOINT_SUFFIX = ".pth.tar"


def discover_checkpoints(directory, suffix=CHECKPOINT_SUFFIX):
    """Recursively find checkpoint files under *directory*.

    Args:
        directory: Root directory to walk.
        suffix: File-name ending that marks a checkpoint file.

    Returns:
        Two parallel lists ``(names, paths)``: each name is the file name with
        *suffix* removed, and ``paths[i]`` is the path of the file for
        ``names[i]``. Traversal is sorted so the order is stable across runs.
        A missing directory yields two empty lists (``os.walk`` ignores the
        error by default).
    """
    names, paths = [], []
    for root, dirs, files in os.walk(directory):
        # Sorting in place makes os.walk descend into subdirectories in a
        # deterministic order.
        dirs.sort()
        for filename in sorted(files):
            if filename.endswith(suffix):
                # Strip only the suffix so dotted names such as
                # "coc.v2.pth.tar" stay distinct ("coc.v2", not "coc").
                names.append(filename[:len(filename) - len(suffix)])
                paths.append(os.path.join(root, filename))
    return names, paths


# Parallel lists: models[i] is the dropdown label for checkpoints[i].
models, checkpoints = discover_checkpoints(checkpoint_dir)
def generate_coc_viz(model_name,
                     image,
                     stage,
                     block,
                     head,
                     alpha):
    """Run the cluster visualization for one image.

    Args:
        model_name: Model label chosen in the dropdown. Its checkpoint is looked
            up in the module-level ``models``/``checkpoints`` lists.
        image: Image value delivered by the Gradio ``Image`` input.
        stage, block, head: Slider selections. They are cast to ``int`` because
            sliders may deliver floats (presumably used as indices downstream;
            TODO confirm against cluster_visualize).
        alpha: Number passed through unchanged.

    Returns:
        ``(probability, coc_visualization)``, ordered to match the Interface's
        ``outputs`` list.
    """
    try:
        # Use the checkpoint discovered for the selected model.
        checkpoint = checkpoints[models.index(model_name)]
    except ValueError:
        # No discovered checkpoint for this name (or no selection): keep the
        # previous hard-coded default file.
        checkpoint = 'coc_tiny_plain.pth.tar'
    args = PredictionArgs(
        model=model_name,
        image=image,
        checkpoint=checkpoint,
        stage=int(stage),
        block=int(block),
        head=int(head),
        alpha=alpha,
    )
    coc_visualization, probability = generate_visualization(args)
    return probability, coc_visualization
# Web UI: the input order must match generate_coc_viz's parameter order, and
# the output order must match its return tuple.
demo = gr.Interface(
    fn=generate_coc_viz,
    inputs=[
        # Default to the first discovered model so the first run does not
        # submit an empty selection.
        gr.Dropdown(models, value=models[0] if models else None, label="Model Name"),
        gr.Image(label="Input Image"),
        gr.Slider(0, 3, step=1, label="Stage"),
        gr.Slider(-1, 4, step=1, label="Block"),
        gr.Slider(0, 7, step=1, label="Head"),
        gr.Number(0.5, label="Alpha"),
    ],
    outputs=[gr.Number(label="Probability"), gr.Image(label="Cluster Visualization")],
)

# Launch only when run as a script, so importing this module does not start
# a server.
if __name__ == "__main__":
    demo.launch()
|