File size: 1,387 Bytes
5bb056f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3972689
ea5d667
5bb056f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bfbc42
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import gradio as gr
from cluster_visualize import PredictionArgs, generate_visualization

# Directory that is searched (recursively) for checkpoint files.
checkpoint_dir = 'checkpoints'

# Parallel lists: models[i] is the name derived from the file at checkpoints[i].
models = []
checkpoints = []

for dirpath, _subdirs, filenames in os.walk(checkpoint_dir):
    # Only files carrying the ".pth.tar" suffix count as checkpoints.
    matching = [name for name in filenames if name.endswith(".pth.tar")]
    # The model name is everything before the first dot of the file name.
    models.extend(name.split('.')[0] for name in matching)
    checkpoints.extend(os.path.join(dirpath, name) for name in matching)

def generate_coc_viz(model_name,
                     image,
                     stage,
                     block,
                     head,
                     alpha):
    """Build the cluster visualization for one image.

    The checkpoint is looked up in the module-level ``models`` /
    ``checkpoints`` lists (parallel lists filled by the checkpoint scan), so
    the weights that get loaded belong to the model chosen in the UI instead
    of always being the same file.  When ``model_name`` is not one of the
    discovered models (for example, nothing was selected), the previously
    hard-coded checkpoint name is used as a fallback.

    Args:
        model_name: Value of the "Model Name" dropdown.
        image: Value of the "Input Image" component, forwarded unchanged.
        stage: Value of the "Stage" slider, forwarded unchanged.
        block: Value of the "Block" slider, forwarded unchanged.
        head: Value of the "Head" slider, forwarded unchanged.
        alpha: Value of the "Alpha" number field, forwarded unchanged
            (presumably an overlay blending weight -- confirm in
            ``cluster_visualize``).

    Returns:
        A ``(probability, coc_visualization)`` tuple; note the order is
        swapped relative to what ``generate_visualization`` returns so that
        it matches the interface's output components.
    """
    default_checkpoint = 'coc_tiny_plain.pth.tar'
    try:
        # models and checkpoints share indices, so the position of the
        # selected name identifies its checkpoint path.
        checkpoint = checkpoints[models.index(model_name)]
    except ValueError:
        # Unknown or missing selection: keep the former fixed behaviour.
        checkpoint = default_checkpoint

    args = PredictionArgs(
        model=model_name,
        image=image,
        checkpoint=checkpoint,
        stage=stage,
        block=block,
        head=head,
        alpha=alpha,
    )
    coc_visualization, probability = generate_visualization(args)
    return probability, coc_visualization


# Input components, listed in the positional order of generate_coc_viz's
# parameters: model_name, image, stage, block, head, alpha.
viz_inputs = [
    gr.components.Dropdown(models, label="Model Name"),
    gr.Image(label="Input Image"),
    gr.Slider(0, 3, step=1, label="Stage"),
    gr.Slider(-1, 4, step=1, label="Block"),
    gr.Slider(0, 7, step=1, label="Head"),
    gr.components.Number(0.5, label="Alpha"),
]

# Output components, matching the (probability, visualization) return order.
viz_outputs = [
    gr.Number(label="Probability"),
    gr.Image(label="Cluster Visualization"),
]

demo = gr.Interface(fn=generate_coc_viz, inputs=viz_inputs, outputs=viz_outputs)

# Start the web UI as soon as the script is executed.
demo.launch()