|
import os |
|
import gradio as gr |
|
from cluster_visualize import PredictionArgs, generate_visualization |
|
|
|
checkpoint_dir = 'checkpoints'

# File-name ending that marks a checkpoint file.
CHECKPOINT_SUFFIX = ".pth.tar"


def find_checkpoints(root_dir, suffix=CHECKPOINT_SUFFIX):
    """Collect model names and checkpoint paths found under ``root_dir``.

    Walks ``root_dir`` recursively and keeps every file whose name ends with
    ``suffix``. The model name is the part of the file name before its first
    dot, so ``coc_tiny_plain.pth.tar`` is listed as ``coc_tiny_plain``.

    Directories and files are visited in sorted order, so the returned lists
    (and therefore the model dropdown order) are stable between runs. Plain
    ``os.walk`` yields entries in arbitrary file-system order.

    Args:
        root_dir: Directory to search. A missing directory yields no entries
            (``os.walk`` ignores the error by default).
        suffix: File-name ending that marks a checkpoint.

    Returns:
        ``(models, checkpoints)``: two parallel lists, where ``checkpoints[i]``
        is the path of the file that ``models[i]`` was derived from.
    """
    found_models = []
    found_paths = []
    for root, dirs, files in os.walk(root_dir):
        # Sorting ``dirs`` in place makes os.walk descend in a fixed order.
        dirs.sort()
        for file in sorted(files):
            if file.endswith(suffix):
                found_models.append(file.split('.')[0])
                found_paths.append(os.path.join(root, file))
    return found_models, found_paths


models, checkpoints = find_checkpoints(checkpoint_dir)
|
|
|
def generate_coc_viz(model_name,
                     image,
                     stage,
                     block,
                     head,
                     alpha):
    """Build the cluster visualization for one image with the chosen model.

    Args:
        model_name: Name chosen in the "Model Name" dropdown; must be one of
            the names in the module-level ``models`` list.
        image: Value from the "Input Image" component, forwarded unchanged
            to ``PredictionArgs``.
        stage: Value of the "Stage" slider, forwarded unchanged.
        block: Value of the "Block" slider, forwarded unchanged.
        head: Value of the "Head" slider, forwarded unchanged.
        alpha: Value of the "Alpha" number box, forwarded unchanged.

    Returns:
        ``(probability, coc_visualization)``: the two values from
        ``generate_visualization`` in swapped order.

    Raises:
        gr.Error: if no model is selected or no checkpoint was discovered
            for ``model_name``.
    """
    # The checkpoint used to be hard-coded to 'coc_tiny_plain.pth.tar'
    # (without the checkpoint directory), so every dropdown choice tried to
    # load that same file. Use the checkpoint discovered for this model;
    # ``models`` and ``checkpoints`` are parallel lists built at import time.
    if model_name not in models:
        raise gr.Error(
            f"No checkpoint found for model {model_name!r}; "
            "select a model from the 'Model Name' dropdown."
        )
    checkpoint = checkpoints[models.index(model_name)]

    args = PredictionArgs(
        model=model_name,
        image=image,
        checkpoint=checkpoint,
        stage=stage,
        block=block,
        head=head,
        alpha=alpha,
    )
    coc_visualization, probability = generate_visualization(args)
    # Order matches the interface outputs: Probability, then Cluster Visualization.
    return probability, coc_visualization
|
|
|
|
|
demo = gr.Interface(
    fn=generate_coc_viz,
    inputs=[
        # Preselect the first discovered model so that submitting without
        # touching the dropdown does not send ``None`` as the model name.
        gr.Dropdown(models, value=models[0] if models else None, label="Model Name"),
        gr.Image(label="Input Image"),
        gr.Slider(0, 3, step=1, label="Stage"),
        gr.Slider(-1, 4, step=1, label="Block"),
        gr.Slider(0, 7, step=1, label="Head"),
        gr.Number(0.5, label="Alpha"),
    ],
    # Order must match generate_coc_viz's return value: (probability, visualization).
    outputs=[gr.Number(label="Probability"), gr.Image(label="Cluster Visualization")],
)


# Start the server only when run as a script; importing this module just
# builds ``demo`` without launching it.
if __name__ == "__main__":
    demo.launch()
|
|