TheStinger tuan2308 commited on
Commit
7bb60d6
·
verified ·
1 Parent(s): 70ce5b0

Update app.py (#4)

Browse files

- Update app.py (1e80830cddc8850dba95d28a1c6843c87745ddf9)


Co-authored-by: Tuan <[email protected]>

Files changed (1) hide show
  1. app.py +80 -53
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import cv2
3
  import numpy
@@ -5,38 +6,39 @@ import os
5
  import random
6
  from basicsr.archs.rrdbnet_arch import RRDBNet
7
  from basicsr.utils.download_util import load_file_from_url
 
8
  from realesrgan import RealESRGANer
9
  from realesrgan.archs.srvgg_arch import SRVGGNetCompact
10
- from torchvision.transforms.functional import rgb_to_grayscale
11
- import spaces
12
 
13
  last_file = None
14
  img_mode = "RGBA"
15
 
16
  @spaces.GPU
17
  def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
18
- """Real-ESRGAN function to restore (and upscale) images."""
 
19
  if not img:
20
  return
21
 
22
  # Define model parameters
23
- if model_name == 'RealESRGAN_x4plus':
24
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
25
  netscale = 4
26
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
27
- elif model_name == 'RealESRNet_x4plus':
28
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
29
  netscale = 4
30
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
31
- elif model_name == 'RealESRGAN_x4plus_anime_6B':
32
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
33
  netscale = 4
34
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
35
- elif model_name == 'RealESRGAN_x2plus':
36
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
37
  netscale = 2
38
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
39
- elif model_name == 'realesr-general-x4v3':
40
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
41
  netscale = 4
42
  file_url = [
@@ -44,19 +46,23 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
44
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
45
  ]
46
 
 
47
  model_path = os.path.join('weights', model_name + '.pth')
48
  if not os.path.isfile(model_path):
49
  ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
50
  for url in file_url:
 
51
  model_path = load_file_from_url(
52
  url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
53
 
 
54
  dni_weight = None
55
  if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
56
  wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
57
  model_path = [model_path, wdn_model_path]
58
  dni_weight = [denoise_strength, 1 - denoise_strength]
59
 
 
60
  upsampler = RealESRGANer(
61
  scale=netscale,
62
  model_path=model_path,
@@ -69,6 +75,7 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
69
  gpu_id=None
70
  )
71
 
 
72
  if face_enhance:
73
  from gfpgan import GFPGANer
74
  face_enhancer = GFPGANer(
@@ -78,9 +85,11 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
78
  channel_multiplier=2,
79
  bg_upsampler=upsampler)
80
 
 
81
  cv_img = numpy.array(img)
82
  img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
83
 
 
84
  try:
85
  if face_enhance:
86
  _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
@@ -90,29 +99,49 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
90
  print('Error', error)
91
  print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
92
  else:
93
- extension = 'png' if img_mode == 'RGBA' else 'jpg'
 
 
 
 
94
 
95
  out_filename = f"output_{rnd_string(8)}.{extension}"
96
  cv2.imwrite(out_filename, output)
97
  global last_file
98
  last_file = out_filename
 
99
 
100
- output_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA) if img_mode == "RGBA" else output
101
- return out_filename, image_properties(output_img)
102
 
103
  def rnd_string(x):
 
 
104
  characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
105
- return "".join((random.choice(characters)) for i in range(x))
 
 
106
 
107
  def reset():
 
 
 
108
  global last_file
109
  if last_file:
110
  print(f"Deleting {last_file} ...")
111
  os.remove(last_file)
112
  last_file = None
113
- return gr.update(value=None), gr.update(value=None), gr.update(value=None)
 
114
 
115
  def has_transparency(img):
 
 
 
 
 
 
 
 
 
116
  if img.info.get("transparency", None) is not None:
117
  return True
118
  if img.mode == "P":
@@ -126,70 +155,68 @@ def has_transparency(img):
126
  return True
127
  return False
128
 
 
129
  def image_properties(img):
130
  """Returns the dimensions (width and height) and color mode of the input image and
131
  also sets the global img_mode variable to be used by the realesrgan function
132
  """
133
  global img_mode
134
- if img is None: # Explicitly check for None
135
- return "No image data available."
136
-
137
- if isinstance(img, numpy.ndarray): # Handle NumPy array case
138
- height, width = img.shape[:2]
139
- channels = img.shape[2] if len(img.shape) > 2 else 1
140
- img_mode = "RGBA" if channels == 4 else "RGB" if channels == 3 else "Grayscale"
141
- return f"Resolution: Width: {width}, Height: {height} | Color Mode: {img_mode}"
142
-
143
- if hasattr(img, "info") and hasattr(img, "mode") and hasattr(img, "size"): # Handle PIL images
144
  if has_transparency(img):
145
  img_mode = "RGBA"
146
  else:
147
  img_mode = "RGB"
148
- return f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
149
-
150
- return "Unsupported image format."
151
 
152
  def main():
153
- with gr.Blocks(theme=gr.themes.Default(primary_hue="pink", secondary_hue="rose"), title="Ilaria Upscaler 💖") as app:
 
154
 
155
  gr.Markdown(
156
- """# <div align="center"> Ilaria Upscaler 💖 </div>
157
  """
158
  )
 
159
  with gr.Accordion("Upscaling option"):
160
  with gr.Row():
161
- model_name = gr.Dropdown(label="Model",
162
- choices=["RealESRGAN_x4plus", "RealESRNet_x4plus", "RealESRGAN_x4plus_anime_6B", "RealESRGAN_x2plus", "realesr-general-x4v3"],
163
- value="RealESRGAN_x4plus")
164
- denoise_strength = gr.Slider(label="Denoise Strength", minimum=0, maximum=1, step=0.1, value=0.5)
165
- outscale = gr.Slider(label="Resolution Upscale", minimum=1, maximum=6, step=1, value=4)
166
- face_enhance = gr.Checkbox(label="Face Enhancement")
167
-
 
 
 
 
168
  with gr.Row():
169
  with gr.Group():
170
- input_image = gr.Image(label="Input Image", type="pil")
171
- input_properties = gr.Textbox(label="Input Image Properties", interactive=False)
172
-
173
- with gr.Group():
174
- output_image = gr.Image(label="Output Image")
175
- output_properties = gr.Textbox(label="Output Image Properties", interactive=False)
176
-
177
  with gr.Row():
178
- reset_btn = gr.Button("Reset")
179
- upscale_btn = gr.Button("Upscale")
180
-
181
- input_image.change(fn=image_properties, inputs=input_image, outputs=input_properties)
182
- upscale_btn.click(fn=realesrgan,
183
- inputs=[input_image, model_name, denoise_strength, face_enhance, outscale],
184
- outputs=[output_image, output_properties])
185
- reset_btn.click(fn=reset, inputs=[], outputs=[input_image, output_image, input_properties])
 
 
 
186
 
187
  gr.Markdown(
188
- """Made with love by Ilaria 💖 | Support me on [Ko-Fi](https://ko-fi.com/ilariaowo) | Using [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN).
189
  """
190
  )
191
 
192
- app.launch()
 
193
 
194
  if __name__ == "__main__":
195
- main()
 
1
+ import spaces
2
  import gradio as gr
3
  import cv2
4
  import numpy
 
6
  import random
7
  from basicsr.archs.rrdbnet_arch import RRDBNet
8
  from basicsr.utils.download_util import load_file_from_url
9
+
10
  from realesrgan import RealESRGANer
11
  from realesrgan.archs.srvgg_arch import SRVGGNetCompact
12
+
 
13
 
14
  last_file = None
15
  img_mode = "RGBA"
16
 
17
  @spaces.GPU
18
  def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
19
+ """Real-ESRGAN function to restore (and upscale) images.
20
+ """
21
  if not img:
22
  return
23
 
24
  # Define model parameters
25
+ if model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
26
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
27
  netscale = 4
28
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
29
+ elif model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
30
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
31
  netscale = 4
32
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
33
+ elif model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
34
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
35
  netscale = 4
36
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
37
+ elif model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
38
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
39
  netscale = 2
40
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
41
+ elif model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
42
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
43
  netscale = 4
44
  file_url = [
 
46
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
47
  ]
48
 
49
+ # Determine model paths
50
  model_path = os.path.join('weights', model_name + '.pth')
51
  if not os.path.isfile(model_path):
52
  ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
53
  for url in file_url:
54
+ # model_path will be updated
55
  model_path = load_file_from_url(
56
  url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
57
 
58
+ # Use dni to control the denoise strength
59
  dni_weight = None
60
  if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
61
  wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
62
  model_path = [model_path, wdn_model_path]
63
  dni_weight = [denoise_strength, 1 - denoise_strength]
64
 
65
+ # Restorer Class
66
  upsampler = RealESRGANer(
67
  scale=netscale,
68
  model_path=model_path,
 
75
  gpu_id=None
76
  )
77
 
78
+ # Use GFPGAN for face enhancement
79
  if face_enhance:
80
  from gfpgan import GFPGANer
81
  face_enhancer = GFPGANer(
 
85
  channel_multiplier=2,
86
  bg_upsampler=upsampler)
87
 
88
+ # Convert the input PIL image to cv2 image, so that it can be processed by realesrgan
89
  cv_img = numpy.array(img)
90
  img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
91
 
92
+ # Apply restoration
93
  try:
94
  if face_enhance:
95
  _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
 
99
  print('Error', error)
100
  print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
101
  else:
102
+ # Save restored image and return it to the output Image component
103
+ if img_mode == 'RGBA': # RGBA images should be saved in png format
104
+ extension = 'png'
105
+ else:
106
+ extension = 'jpg'
107
 
108
  out_filename = f"output_{rnd_string(8)}.{extension}"
109
  cv2.imwrite(out_filename, output)
110
  global last_file
111
  last_file = out_filename
112
+ return out_filename
113
 
 
 
114
 
115
  def rnd_string(x):
116
+ """Returns a string of 'x' random characters
117
+ """
118
  characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
119
+ result = "".join((random.choice(characters)) for i in range(x))
120
+ return result
121
+
122
 
123
  def reset():
124
+ """Resets the Image components of the Gradio interface and deletes
125
+ the last processed image
126
+ """
127
  global last_file
128
  if last_file:
129
  print(f"Deleting {last_file} ...")
130
  os.remove(last_file)
131
  last_file = None
132
+ return gr.update(value=None), gr.update(value=None)
133
+
134
 
135
  def has_transparency(img):
136
+ """This function works by first checking to see if a "transparency" property is defined
137
+ in the image's info -- if so, we return "True". Then, if the image is using indexed colors
138
+ (such as in GIFs), it gets the index of the transparent color in the palette
139
+ (img.info.get("transparency", -1)) and checks if it's used anywhere in the canvas
140
+ (img.getcolors()). If the image is in RGBA mode, then presumably it has transparency in
141
+ it, but it double-checks by getting the minimum and maximum values of every color channel
142
+ (img.getextrema()), and checks if the alpha channel's smallest value falls below 255.
143
+ https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
144
+ """
145
  if img.info.get("transparency", None) is not None:
146
  return True
147
  if img.mode == "P":
 
155
  return True
156
  return False
157
 
158
+
159
  def image_properties(img):
160
  """Returns the dimensions (width and height) and color mode of the input image and
161
  also sets the global img_mode variable to be used by the realesrgan function
162
  """
163
  global img_mode
164
+ if img:
 
 
 
 
 
 
 
 
 
165
  if has_transparency(img):
166
  img_mode = "RGBA"
167
  else:
168
  img_mode = "RGB"
169
+ properties = f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
170
+ return properties
171
+
172
 
173
  def main():
174
+ # Gradio Interface
175
+ with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
176
 
177
  gr.Markdown(
178
+ """ Image Upscaler
179
  """
180
  )
181
+
182
  with gr.Accordion("Upscaling option"):
183
  with gr.Row():
184
+ model_name = gr.Dropdown(label="Upscaler model",
185
+ choices=["RealESRGAN_x4plus", "RealESRNet_x4plus", "RealESRGAN_x4plus_anime_6B",
186
+ "RealESRGAN_x2plus", "realesr-general-x4v3"],
187
+ value="RealESRGAN_x4plus_anime_6B", show_label=True)
188
+ denoise_strength = gr.Slider(label="Denoise Strength",
189
+ minimum=0, maximum=1, step=0.1, value=0.5)
190
+ outscale = gr.Slider(label="Resolution upscale",
191
+ minimum=1, maximum=6, step=1, value=4, show_label=True)
192
+ face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)",
193
+ )
194
+
195
  with gr.Row():
196
  with gr.Group():
197
+ input_image = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
198
+ input_image_properties = gr.Textbox(label="Image Properties", max_lines=1)
199
+ output_image = gr.Image(label="Output Image", image_mode="RGBA")
 
 
 
 
200
  with gr.Row():
201
+ reset_btn = gr.Button("Remove images")
202
+ restore_btn = gr.Button("Upscale")
203
+
204
+ # Event listeners:
205
+ input_image.change(fn=image_properties, inputs=input_image, outputs=input_image_properties)
206
+ restore_btn.click(fn=realesrgan,
207
+ inputs=[input_image, model_name, denoise_strength, face_enhance, outscale],
208
+ outputs=output_image)
209
+ reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])
210
+ # reset_btn.click(None, inputs=[], outputs=[input_image], _js="() => (null)\n")
211
+ # Undocumented method to clear a component's value using Javascript
212
 
213
  gr.Markdown(
214
+ """
215
  """
216
  )
217
 
218
+ demo.launch()
219
+
220
 
221
  if __name__ == "__main__":
222
+ main()