barry-ravichandran commited on
Commit
70d905f
·
1 Parent(s): c73bbfb

Updated app to support gradio v4.7.1

Browse files
Files changed (1) hide show
  1. app.py +29 -71
app.py CHANGED
@@ -1,23 +1,7 @@
1
- # ---
2
- # jupyter:
3
- # jupytext:
4
- # text_representation:
5
- # extension: .py
6
- # format_name: light
7
- # format_version: '1.5'
8
- # jupytext_version: 1.15.2
9
- # kernelspec:
10
- # display_name: Python 3
11
- # language: python
12
- # name: python3
13
- # ---
14
-
15
- # # Gradio Example <a name="XAITK-Saliency-Gradio-Example"></a>
16
- # This notebook makes use of the saliency generation example found in the base ``xaitk-saliency`` repo [here](https://github.com/XAITK/xaitk-saliency/blob/master/examples/OcclusionSaliency.ipynb), and explores integrating ``xaitk-saliency`` with ``Gradio`` to create an interactive interface for computing saliency maps.
17
- #
18
- # ## Test Image <a name="Test-Image-Gradio"></a>
19
-
20
- # +
21
  import os
22
  import PIL.Image
23
  import matplotlib.pyplot as plt # type: ignore
@@ -27,11 +11,10 @@ import numpy as np
27
  import gradio as gr
28
  from gradio import ( # type: ignore
29
  AnnotatedImage, Button, Column, Image, Label, # type: ignore
30
- Number, Plot, Row, TabItem, Tab, Tabs # type: ignore
 
31
  )
32
- from gradio import components as gr_components # type: ignore
33
 
34
- # +
35
  # State variables for Image Classification
36
  from gr_component_state import ( # type: ignore
37
  img_cls_model_name, img_cls_saliency_algo_name, window_size_state, stride_state, debiased_state,
@@ -65,25 +48,16 @@ from gr_component_state import ( # type: ignore
65
  import torch
66
  import torchvision.transforms as transforms
67
  import torchvision.models as models
 
68
 
69
  from smqtk_detection.impls.detect_image_objects.resnet_frcnn import ResNetFRCNN
70
  from xaitk_saliency.impls.gen_image_classifier_blackbox_sal.slidingwindow import SlidingWindowStack
71
  from xaitk_saliency.impls.gen_image_classifier_blackbox_sal.rise import RISEStack
72
  from xaitk_saliency.impls.gen_object_detector_blackbox_sal.drise import RandomGridStack, DRISEStack
73
-
74
- import torch.nn.functional
75
- from smqtk_classifier.interfaces.classify_image import ClassifyImage
76
-
77
- import numpy as np
78
- from gradio import ( # type: ignore
79
- Checkbox, Dropdown, SelectData, Slider, Textbox # type: ignore
80
- )
81
- from gradio import processing_utils as gr_processing_utils # type: ignore
82
  from xaitk_saliency.interfaces.gen_object_detector_blackbox_sal import GenerateObjectDetectorBlackboxSaliency
83
  from smqtk_detection.interfaces.detect_image_objects import DetectImageObjects
 
84
 
85
- # Use JPEG format for inline visualizations here.
86
- # %config InlineBackend.figure_format = "jpeg"
87
 
88
  os.makedirs('data', exist_ok=True)
89
  test_image_filename = 'data/catdog.jpg'
@@ -91,15 +65,6 @@ urllib.request.urlretrieve('https://farm1.staticflickr.com/74/202734059_fcce636d
91
  plt.figure(figsize=(12, 8))
92
  plt.axis('off')
93
  _ = plt.imshow(PIL.Image.open(test_image_filename))
94
- # -
95
-
96
- # ## Initialize state variables for Gradio components <a name="Global-State-Gradio"></a>
97
- # Gradio expects either a list or dict format to maintain state variables based on the use case. The cell below initializes the state variables from the ``gr_component_state.py`` file for the various components in our gradio demo.
98
-
99
-
100
-
101
- # ## Helper Functions <a name="Helper-Functions-Gradio"></a>
102
- # The functions defined in the cell below are used to set up the model, saliency algorithm, class labels and image transforms needed for the demo.
103
 
104
  CUDA_AVAILABLE = torch.cuda.is_available()
105
 
@@ -273,31 +238,31 @@ sal_obj_labels, sal_obj_idxs = get_det_sal_labels(obj_classes_file)
273
  # Modify textbox parameters based on chosen saliency algorithm
274
  def show_textbox_parameters(choice):
275
  if choice == 'RISE':
276
- return Textbox.update(visible=False), Textbox.update(visible=False), Textbox.update(visible=True), Textbox.update(visible=True), Textbox.update(visible=True)
277
  elif choice == 'SlidingWindowStack':
278
- return Textbox.update(visible=True), Textbox.update(visible=True), Textbox.update(visible=False), Textbox.update(visible=False), Textbox.update(visible=False)
279
  elif choice == "RandomGridStack":
280
- return Textbox.update(visible=True), Textbox.update(visible=False), Textbox.update(visible=True), Textbox.update(visible=True)
281
  elif choice == "DRISE":
282
- return Textbox.update(visible=True), Textbox.update(visible=True), Textbox.update(visible=True), Textbox.update(visible=False)
283
  else:
284
  raise Exception("Unknown Input")
285
 
286
  # Modify slider parameters based on chosen saliency algorithm
287
  def show_slider_parameters(choice):
288
  if choice == 'RISE' or choice == 'RandomGridStack' or choice == 'DRISE':
289
- return Slider.update(visible=True), Slider.update(visible=True)
290
  elif choice == 'SlidingWindowStack':
291
- return Slider.update(visible=True), Slider.update(visible=False)
292
  else:
293
  raise Exception("Unknown Input")
294
 
295
  # Modify checkbox parameters based on chosen saliency algorithm
296
  def show_debiased_checkbox(choice):
297
  if choice == 'RISE':
298
- return Checkbox.update(visible=True)
299
  elif choice == 'SlidingWindowStack' or choice == 'RandomGridStack' or choice == 'DRISE':
300
- return Checkbox.update(visible=False)
301
  else:
302
  raise Exception("Unknown Input")
303
 
@@ -313,7 +278,7 @@ def predict(x,top_n_classes):
313
  labels = list(zip(sal_class_labels, class_conf[sal_class_idxs].tolist()))
314
  final_labels = dict(sorted(labels, key=lambda t: t[1],reverse=True)[:top_n_classes])
315
 
316
- return final_labels, Dropdown(choices=list(final_labels),label="Class to compute saliency",interactive=True,visible=True)
317
 
318
  # Interpretation function for image classification that implements the selected saliency algorithm and generates the class-wise saliency map visualizations
319
  def interpretation_function(image: np.ndarray,
@@ -390,7 +355,7 @@ def run_detect(input_img: np.ndarray, num_detections: int):
390
 
391
  bboxes_list = bboxes[:,:].astype(int).tolist()
392
 
393
- return (input_img, list(zip([f for f in bboxes_list], [l for l in final_label]))[:num_detections]), Dropdown(choices=[l for l in final_label][:num_detections],label="Detection to compute saliency",interactive=True,visible=True)
394
 
395
  # Run saliency algorithm on the object detect predictions and generate corresponding visualizations
396
  def run_detect_saliency(input_img: np.ndarray,
@@ -456,11 +421,6 @@ def gen_det_saliency(input_img: np.ndarray,
456
 
457
  return sal_maps
458
 
459
- # Event handler that populates the dropdown list of classes based on the Label/AnnotatedImage components' output
460
- def map_labels(evt: SelectData):
461
-
462
- return str(evt.value)
463
-
464
  with gr.Blocks() as demo:
465
  with Tab("Image Classification"):
466
  with Row():
@@ -471,19 +431,19 @@ with gr.Blocks() as demo:
471
  with Row():
472
  with Column(scale=0.33):
473
  window_size = Textbox(value=window_size_state[-1],label="Tuple of window size values (Press Enter to submit the input)",interactive=True,visible=False)
474
- masks = Number(value=num_masks_state[-1],label="Number of Random Masks (Press Enter to submit the input)",interactive=True,visible=False,precision=0)
475
  with Column(scale=0.33):
476
  stride = Textbox(value=stride_state[-1],label="Tuple of stride values (Press Enter to submit the input)" ,interactive=True,visible=False)
477
- spatial_res = Number(value=spatial_res_state[-1],label="Spatial Resolution of Masking Grid (Press Enter to submit the input)" ,interactive=True,visible=False,precision=0)
478
  with Column(scale=0.33):
479
- threads = Slider(value=threads_state[-1],label="Threads",interactive=True,visible=False)
480
  with Row():
481
  with Column(scale=0.33):
482
- seed = Number(value=seed_state[-1],label="Seed (Press Enter to submit the input)",interactive=True,visible=False,precision=0)
483
  with Column(scale=0.33):
484
- p1 = Slider(value=p1_state[-1],label="P1",interactive=True,visible=False, minimum=0,maximum=1,step=0.1)
485
  with Column(scale=0.33):
486
- debiased = Checkbox(value=debiased_state[-1],label="Debiased", interactive=True, visible=False)
487
  with Row():
488
  with Column():
489
  input_img = Image(label="Saliency Map Generation", width=640, height=480)
@@ -515,14 +475,14 @@ with gr.Blocks() as demo:
515
  drop_list_detect_sal = Dropdown(value=obj_det_saliency_algo_name[-1],choices=["RandomGridStack","DRISE"],label="Choose Saliency Algorithm",interactive="True")
516
  with Row():
517
  with Column(scale=0.33):
518
- masks_detect = Number(value=num_masks_state[-1],label="Number of Random Masks (Press Enter to submit the input)",interactive=True,visible=False,precision=0)
519
  occlusion_grid_size = Textbox(value=occlusion_grid_state[-1],label="Tuple of occlusion grid size values (Press Enter to submit the input)",interactive=True,visible=False)
520
- spatial_res_detect = Number(value=spatial_res_state[-1],label="Spatial Resolution of Masking Grid (Press Enter to submit the input)" ,interactive=True,visible=False,precision=0)
521
  with Column(scale=0.33):
522
- seed_detect = Number(value=seed_state[-1],label="Seed (Press Enter to submit the input)",interactive=True,visible=False,precision=0)
523
- p1_detect = Slider(value=p1_state[-1],label="P1",interactive=True,visible=False, minimum=0,maximum=1,step=0.1)
524
  with Column(scale=0.33):
525
- threads_detect = Slider(value=threads_state[-1],label="Threads",interactive=True,visible=False)
526
  with Row():
527
  with Column():
528
  input_img_detect = Image(label="Saliency Map Generation", width=640, height=480)
@@ -565,7 +525,6 @@ with gr.Blocks() as demo:
565
 
566
  # Image Classification prediction and saliency generation event listeners
567
  classify.click(predict, [input_img, num_classes], [class_label,class_name])
568
- class_label.select(map_labels,None,class_name)
569
  generate_saliency.click(interpretation_function, [input_img, class_label, class_name, img_alpha, sal_alpha, min_sal_range, max_sal_range], [interpretation_plot])
570
 
571
  # Object Detection dropdown list event listeners
@@ -584,7 +543,6 @@ with gr.Blocks() as demo:
584
 
585
  # Object detection prediction, class selection and saliency generation event listeners
586
  detection.click(run_detect, [input_img_detect, num_detections], [detect_label,class_name_det])
587
- detect_label.select(map_labels, None, class_name_det)
588
  generate_det_saliency.click(run_detect_saliency,[input_img_detect, num_detections, class_name_det, img_alpha_det, sal_alpha_det, min_sal_range_det, max_sal_range_det],det_saliency_plot)
589
 
590
  demo.launch()
 
1
+ ## Gradio Example
2
+
3
+ # This app makes use of the saliency generation example found in the base ``xaitk-saliency`` repo [here](https://github.com/XAITK/xaitk-saliency/blob/master/examples/OcclusionSaliency.ipynb), and explores integrating ``xaitk-saliency`` with ``Gradio`` to create an interactive interface for computing saliency maps.
4
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import os
6
  import PIL.Image
7
  import matplotlib.pyplot as plt # type: ignore
 
11
  import gradio as gr
12
  from gradio import ( # type: ignore
13
  AnnotatedImage, Button, Column, Image, Label, # type: ignore
14
+ Number, Plot, Row, TabItem, Tab, Tabs, # type: ignore
15
+ Checkbox, Dropdown, Slider, Textbox # type: ignore
16
  )
 
17
 
 
18
  # State variables for Image Classification
19
  from gr_component_state import ( # type: ignore
20
  img_cls_model_name, img_cls_saliency_algo_name, window_size_state, stride_state, debiased_state,
 
48
  import torch
49
  import torchvision.transforms as transforms
50
  import torchvision.models as models
51
+ import torch.nn.functional
52
 
53
  from smqtk_detection.impls.detect_image_objects.resnet_frcnn import ResNetFRCNN
54
  from xaitk_saliency.impls.gen_image_classifier_blackbox_sal.slidingwindow import SlidingWindowStack
55
  from xaitk_saliency.impls.gen_image_classifier_blackbox_sal.rise import RISEStack
56
  from xaitk_saliency.impls.gen_object_detector_blackbox_sal.drise import RandomGridStack, DRISEStack
 
 
 
 
 
 
 
 
 
57
  from xaitk_saliency.interfaces.gen_object_detector_blackbox_sal import GenerateObjectDetectorBlackboxSaliency
58
  from smqtk_detection.interfaces.detect_image_objects import DetectImageObjects
59
+ from smqtk_classifier.interfaces.classify_image import ClassifyImage
60
 
 
 
61
 
62
  os.makedirs('data', exist_ok=True)
63
  test_image_filename = 'data/catdog.jpg'
 
65
  plt.figure(figsize=(12, 8))
66
  plt.axis('off')
67
  _ = plt.imshow(PIL.Image.open(test_image_filename))
 
 
 
 
 
 
 
 
 
68
 
69
  CUDA_AVAILABLE = torch.cuda.is_available()
70
 
 
238
  # Modify textbox parameters based on chosen saliency algorithm
239
  def show_textbox_parameters(choice):
240
  if choice == 'RISE':
241
+ return Textbox(visible=False), Textbox(visible=False), Textbox(visible=True), Textbox(visible=True), Textbox(visible=True)
242
  elif choice == 'SlidingWindowStack':
243
+ return Textbox(visible=True), Textbox(visible=True), Textbox(visible=False), Textbox(visible=False), Textbox(visible=False)
244
  elif choice == "RandomGridStack":
245
+ return Textbox(visible=True), Textbox(visible=False), Textbox(visible=True), Textbox(visible=True)
246
  elif choice == "DRISE":
247
+ return Textbox(visible=True), Textbox(visible=True), Textbox(visible=True), Textbox(visible=False)
248
  else:
249
  raise Exception("Unknown Input")
250
 
251
  # Modify slider parameters based on chosen saliency algorithm
252
  def show_slider_parameters(choice):
253
  if choice == 'RISE' or choice == 'RandomGridStack' or choice == 'DRISE':
254
+ return Slider(visible=True), Slider(visible=True)
255
  elif choice == 'SlidingWindowStack':
256
+ return Slider(visible=True), Slider(visible=False)
257
  else:
258
  raise Exception("Unknown Input")
259
 
260
  # Modify checkbox parameters based on chosen saliency algorithm
261
  def show_debiased_checkbox(choice):
262
  if choice == 'RISE':
263
+ return Checkbox(visible=True)
264
  elif choice == 'SlidingWindowStack' or choice == 'RandomGridStack' or choice == 'DRISE':
265
+ return Checkbox(visible=False)
266
  else:
267
  raise Exception("Unknown Input")
268
 
 
278
  labels = list(zip(sal_class_labels, class_conf[sal_class_idxs].tolist()))
279
  final_labels = dict(sorted(labels, key=lambda t: t[1],reverse=True)[:top_n_classes])
280
 
281
+ return final_labels, Dropdown(choices=list(final_labels))
282
 
283
  # Interpretation function for image classification that implements the selected saliency algorithm and generates the class-wise saliency map visualizations
284
  def interpretation_function(image: np.ndarray,
 
355
 
356
  bboxes_list = bboxes[:,:].astype(int).tolist()
357
 
358
+ return (input_img, list(zip([f for f in bboxes_list], [l for l in final_label]))[:num_detections]), Dropdown(choices=[l for l in final_label][:num_detections])
359
 
360
  # Run saliency algorithm on the object detect predictions and generate corresponding visualizations
361
  def run_detect_saliency(input_img: np.ndarray,
 
421
 
422
  return sal_maps
423
 
 
 
 
 
 
424
  with gr.Blocks() as demo:
425
  with Tab("Image Classification"):
426
  with Row():
 
431
  with Row():
432
  with Column(scale=0.33):
433
  window_size = Textbox(value=window_size_state[-1],label="Tuple of window size values (Press Enter to submit the input)",interactive=True,visible=False)
434
+ masks = Number(value=num_masks_state[-1],label="Number of Random Masks (Press Enter to submit the input)",interactive=True,visible=True,precision=0)
435
  with Column(scale=0.33):
436
  stride = Textbox(value=stride_state[-1],label="Tuple of stride values (Press Enter to submit the input)" ,interactive=True,visible=False)
437
+ spatial_res = Number(value=spatial_res_state[-1],label="Spatial Resolution of Masking Grid (Press Enter to submit the input)" ,interactive=True,visible=True,precision=0)
438
  with Column(scale=0.33):
439
+ threads = Slider(value=threads_state[-1],label="Threads",interactive=True,visible=True)
440
  with Row():
441
  with Column(scale=0.33):
442
+ seed = Number(value=seed_state[-1],label="Seed (Press Enter to submit the input)",interactive=True,visible=True,precision=0)
443
  with Column(scale=0.33):
444
+ p1 = Slider(value=p1_state[-1],label="P1",interactive=True,visible=True, minimum=0,maximum=1,step=0.1)
445
  with Column(scale=0.33):
446
+ debiased = Checkbox(value=debiased_state[-1],label="Debiased", interactive=True, visible=True)
447
  with Row():
448
  with Column():
449
  input_img = Image(label="Saliency Map Generation", width=640, height=480)
 
475
  drop_list_detect_sal = Dropdown(value=obj_det_saliency_algo_name[-1],choices=["RandomGridStack","DRISE"],label="Choose Saliency Algorithm",interactive="True")
476
  with Row():
477
  with Column(scale=0.33):
478
+ masks_detect = Number(value=num_masks_state[-1],label="Number of Random Masks (Press Enter to submit the input)",interactive=True,visible=True,precision=0)
479
  occlusion_grid_size = Textbox(value=occlusion_grid_state[-1],label="Tuple of occlusion grid size values (Press Enter to submit the input)",interactive=True,visible=False)
480
+ spatial_res_detect = Number(value=spatial_res_state[-1],label="Spatial Resolution of Masking Grid (Press Enter to submit the input)" ,interactive=True,visible=True,precision=0)
481
  with Column(scale=0.33):
482
+ seed_detect = Number(value=seed_state[-1],label="Seed (Press Enter to submit the input)",interactive=True,visible=True,precision=0)
483
+ p1_detect = Slider(value=p1_state[-1],label="P1",interactive=True,visible=True, minimum=0,maximum=1,step=0.1)
484
  with Column(scale=0.33):
485
+ threads_detect = Slider(value=threads_state[-1],label="Threads",interactive=True,visible=True)
486
  with Row():
487
  with Column():
488
  input_img_detect = Image(label="Saliency Map Generation", width=640, height=480)
 
525
 
526
  # Image Classification prediction and saliency generation event listeners
527
  classify.click(predict, [input_img, num_classes], [class_label,class_name])
 
528
  generate_saliency.click(interpretation_function, [input_img, class_label, class_name, img_alpha, sal_alpha, min_sal_range, max_sal_range], [interpretation_plot])
529
 
530
  # Object Detection dropdown list event listeners
 
543
 
544
  # Object detection prediction, class selection and saliency generation event listeners
545
  detection.click(run_detect, [input_img_detect, num_detections], [detect_label,class_name_det])
 
546
  generate_det_saliency.click(run_detect_saliency,[input_img_detect, num_detections, class_name_det, img_alpha_det, sal_alpha_det, min_sal_range_det, max_sal_range_det],det_saliency_plot)
547
 
548
  demo.launch()