GoBoKyung commited on
Commit
c1ae8c8
·
1 Parent(s): 95342a4
Files changed (1) hide show
  1. app.py +6 -9
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
-
3
  from matplotlib import gridspec
4
  import matplotlib.pyplot as plt
5
  import numpy as np
@@ -75,8 +74,9 @@ def draw_plot(pred_img, seg):
75
  ax.tick_params(width=0.0, labelsize=25)
76
  return fig
77
 
78
- def sepia(input_img):
79
- input_img = Image.fromarray(input_img)
 
80
 
81
  inputs = feature_extractor(images=input_img, return_tensors="tf")
82
  outputs = model(**inputs)
@@ -85,16 +85,15 @@ def sepia(input_img):
85
  logits = tf.transpose(logits, [0, 2, 3, 1])
86
  logits = tf.image.resize(
87
  logits, input_img.size[::-1]
88
- ) # We reverse the shape of `image` because `image.size` returns width and height.
89
  seg = tf.math.argmax(logits, axis=-1)[0]
90
 
91
  color_seg = np.zeros(
92
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
93
- ) # height, width, 3
94
  for label, color in enumerate(colormap):
95
  color_seg[seg.numpy() == label, :] = color
96
 
97
- # Show image + mask
98
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
99
  pred_img = pred_img.astype(np.uint8)
100
 
@@ -102,11 +101,9 @@ def sepia(input_img):
102
  return fig
103
 
104
  demo = gr.Interface(fn=sepia,
105
- inputs=gr.Textbox(text="Enter text here"),
106
  outputs=['plot'],
107
  examples=["img_1.jpg", "img_2.jpeg", "img_3.jpg", "img_4.jpg", "img_5.png"],
108
  allow_flagging='never')
109
 
110
-
111
-
112
  demo.launch()
 
1
  import gradio as gr
 
2
  from matplotlib import gridspec
3
  import matplotlib.pyplot as plt
4
  import numpy as np
 
74
  ax.tick_params(width=0.0, labelsize=25)
75
  return fig
76
 
77
+ def sepia(input_text):
78
+ # Load the image using the input text (assumed to be a path to an image)
79
+ input_img = Image.open(input_text)
80
 
81
  inputs = feature_extractor(images=input_img, return_tensors="tf")
82
  outputs = model(**inputs)
 
85
  logits = tf.transpose(logits, [0, 2, 3, 1])
86
  logits = tf.image.resize(
87
  logits, input_img.size[::-1]
88
+ )
89
  seg = tf.math.argmax(logits, axis=-1)[0]
90
 
91
  color_seg = np.zeros(
92
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
93
+ )
94
  for label, color in enumerate(colormap):
95
  color_seg[seg.numpy() == label, :] = color
96
 
 
97
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
98
  pred_img = pred_img.astype(np.uint8)
99
 
 
101
  return fig
102
 
103
  demo = gr.Interface(fn=sepia,
104
+ inputs=gr.Textbox(text="Enter image file path"),
105
  outputs=['plot'],
106
  examples=["img_1.jpg", "img_2.jpeg", "img_3.jpg", "img_4.jpg", "img_5.png"],
107
  allow_flagging='never')
108
 
 
 
109
  demo.launch()