COC-VIZ / app.py
CurHarsh's picture
Update app.py
2bfbc42
import os
import gradio as gr
from cluster_visualize import PredictionArgs, generate_visualization
checkpoint_dir = 'checkpoints'
models = []
checkpoints = []
for root, dirs, files in os.walk(checkpoint_dir):
for file in files:
if file.endswith(".pth.tar"):
models.append(file.split('.')[0])
checkpoints.append(os.path.join(root, file))
def generate_coc_viz(model_name,
image,
stage,
block,
head,
alpha):
# model_index = models.index(model_name)
checkpoint = 'coc_tiny_plain.pth.tar'
args = PredictionArgs(
model=model_name,
image=image,
checkpoint=checkpoint,
stage=stage,
block=block,
head=head,
alpha=alpha
)
coc_visualization, probability = generate_visualization(args)
return probability, coc_visualization
demo = gr.Interface(
fn=generate_coc_viz,
inputs=[gr.components.Dropdown(models, label="Model Name"),
gr.Image(label="Input Image"),
gr.Slider(0, 3, step=1, label="Stage"),
gr.Slider(-1, 4, step=1, label="Block"),
gr.Slider(0, 7, step=1, label="Head"),
gr.components.Number(0.5, label="Alpha")],
outputs=[gr.Number(label="Probability"), gr.Image(label="Cluster Visualization")],
)
demo.launch()