CurHarsh commited on
Commit
5bb056f
·
1 Parent(s): 7699de1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from cluster_visualize import PredictionArgs, generate_visualization
4
+
5
+ checkpoint_dir = 'checkpoints'
6
+
7
+ models = []
8
+ checkpoints = []
9
+
10
+ for root, dirs, files in os.walk(checkpoint_dir):
11
+ for file in files:
12
+ if file.endswith(".pth.tar"):
13
+ models.append(file.split('.')[0])
14
+ checkpoints.append(os.path.join(root, file))
15
+
16
+ def generate_coc_viz(model_name,
17
+ image,
18
+ stage,
19
+ block,
20
+ head,
21
+ alpha):
22
+ model_index = models.index(model_name)
23
+ checkpoint = checkpoints[model_index]
24
+ args = PredictionArgs(
25
+ model=model_name,
26
+ image=image,
27
+ checkpoint=checkpoint,
28
+ stage=stage,
29
+ block=block,
30
+ head=head,
31
+ alpha=alpha
32
+ )
33
+ coc_visualization, probability = generate_visualization(args)
34
+ return probability, coc_visualization
35
+
36
+
37
+ demo = gr.Interface(
38
+ fn=generate_coc_viz,
39
+ inputs=[gr.components.Dropdown(models, label="Model Name"),
40
+ gr.Image(label="Input Image"),
41
+ gr.Slider(0, 3, step=1, label="Stage"),
42
+ gr.Slider(-1, 4, step=1, label="Block"),
43
+ gr.Slider(0, 7, step=1, label="Head"),
44
+ gr.components.Number(0.5, label="Alpha")],
45
+ outputs=[gr.Number(label="Probability"), gr.Image(label="Cluster Visualization")],
46
+ )
47
+
48
+ demo.launch(share=True)