GwanHyeong's picture
Upload folder using huggingface_hub
8c8af64 verified
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
import cv2
import numpy as np
import PIL
from PIL import Image
from PIL.ImageOps import exif_transpose
import os
import gradio as gr
import datetime
import pickle
from copy import deepcopy
# Side length of the square area used for displaying/editing images.
LENGTH = 480
def clear_all(length=480):
    """Build the updates that reset the whole labeling UI.

    Args:
        length: height and width applied to both image components.

    Returns:
        A 5-tuple: two ``gr.Image`` updates that clear the image value and
        set height/width to ``length``, followed by an empty list and two
        ``None`` values.
    """
    blank_canvas = gr.Image.update(value=None, height=length, width=length)
    blank_input = gr.Image.update(value=None, height=length, width=length)
    return blank_canvas, blank_input, [], None, None
def mask_image(image,
               mask,
               color=[255,0,0],
               alpha=0.5):
    """Blend a colored mask over an image for visualization.

    Args:
        image (H, W, 3) or (H, W): input image; it is not modified
        mask (H, W): mask to overlay; pixels equal to 1 are painted
        color: the color painted on the masked pixels
        alpha: weight of the painted copy in the blend (the untouched
            copy gets ``1 - alpha``)

    Returns:
        The blended image produced by ``cv2.addWeighted``.
    """
    # Copy of the input with the masked pixels replaced by the solid color.
    painted = deepcopy(image)
    painted[mask == 1] = color

    # Second copy keeps the untouched pixels and doubles as the output buffer.
    blended = deepcopy(image)
    return cv2.addWeighted(painted, alpha, blended, 1-alpha, 0, blended)
def store_img(img, length=512):
    """Resize an uploaded image and its sketched mask and build a preview.

    Args:
        img: mapping with an "image" array of shape (H, W, C) and a "mask"
            array whose first channel holds the sketched mask.
        length: width of the resized image; the height is scaled by the
            original height/width ratio.

    Returns:
        A 4-tuple ``(image, [], masked_img, mask)``: the resized image, an
        empty list of selected points, the preview image, and the resized
        mask (converted to a uint8 0/1 array only when it is non-empty).
    """
    source = img["image"]
    sketch = np.float32(img["mask"][:, :, 0]) / 255.
    height, width, _ = source.shape
    # (width, height) of the output, keeping the original aspect ratio.
    target_size = (length, int(length*height/width))

    pil_image = exif_transpose(Image.fromarray(source))
    pil_image = pil_image.resize(target_size, PIL.Image.BILINEAR)
    mask = cv2.resize(sketch, target_size, interpolation=cv2.INTER_NEAREST)
    image = np.array(pil_image)

    # when new image is uploaded, `selected_points` should be empty
    if not (mask.sum() > 0):
        # Nothing was sketched: the preview is a plain copy of the image.
        return image, [], image.copy(), mask

    mask = np.uint8(mask > 0)
    # Blend black over everything outside the sketched region.
    masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    return image, [], masked_img, mask
# the user clicks on the image to select points, which are then drawn on the image
def get_points(img,
               sel_pix,
               evt: gr.SelectData):
    """Record the clicked point and redraw every point/arrow on the image.

    Args:
        img: image to draw on; it is drawn on in place.
        sel_pix: list of points selected so far; ``evt.index`` is appended.
        evt: selection event whose ``index`` is the newly clicked point.

    Returns:
        ``img`` itself when it is a numpy array, otherwise ``np.array(img)``.
    """
    # collect the selected point
    sel_pix.append(evt.index)

    handle_color = (255, 0, 0)  # red: even positions (handle points)
    target_color = (0, 0, 255)  # blue: odd positions (target points)

    pair = []
    for position, coords in enumerate(sel_pix):
        center = tuple(coords)
        fill = handle_color if position % 2 == 0 else target_color
        cv2.circle(img, center, 10, fill, -1)
        pair.append(center)
        # once a handle/target pair is complete, connect it with an arrow
        if len(pair) == 2:
            cv2.arrowedLine(img, pair[0], pair[1], (255, 255, 255), 4, tipLength=0.5)
            pair = []

    if isinstance(img, np.ndarray):
        return img
    return np.array(img)
# clear all handle/target points
def undo_points(original_image,
                mask):
    """Rebuild the click-free preview image and reset the point list.

    Args:
        original_image: the stored image without any drawn points.
        mask: the stored mask; any positive entry counts as masked.

    Returns:
        ``(preview, [])`` where ``preview`` is a copy of ``original_image``
        when the mask is empty, otherwise the image with black blended over
        the area outside the mask.
    """
    if not (mask.sum() > 0):
        # Empty mask: show the original image untouched.
        return original_image.copy(), []

    binary_mask = np.uint8(mask > 0)
    preview = mask_image(original_image, 1 - binary_mask, color=[0, 0, 0], alpha=0.3)
    return preview, []
def save_all(category,
             source_image,
             image_with_clicks,
             mask,
             labeler,
             prompt,
             points,
             root_dir='./drag_bench_data'):
    """Save one labeled sample (images plus metadata) to disk.

    The sample is written to ``root_dir/category/<labeler>_<timestamp>/``:
    ``original_image.png``, ``user_drag.png`` and a pickled ``meta_data.pkl``
    holding the prompt, points and mask.

    Args:
        category: name of the sub-directory under ``root_dir``.
        source_image: array saved as ``original_image.png``.
        image_with_clicks: array saved as ``user_drag.png``.
        mask: mask stored in the metadata.
        labeler: prefix of the per-sample directory name.
        prompt: prompt stored in the metadata.
        points: selected points stored in the metadata.
        root_dir: top-level output directory.

    Returns:
        A status string of the form ``"<save_prefix> saved!"``.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
    save_prefix = labeler + '_' + timestamp
    # NOTE(review): `category` and `labeler` are used verbatim as path
    # components; confirm that callers only pass trusted values.
    save_dir = os.path.join(root_dir, category, save_prefix)
    # One call creates root_dir, the category directory and the sample
    # directory, including missing parents, and tolerates existing ones.
    os.makedirs(save_dir, exist_ok=True)

    # save images
    Image.fromarray(source_image).save(os.path.join(save_dir, 'original_image.png'))
    Image.fromarray(image_with_clicks).save(os.path.join(save_dir, 'user_drag.png'))

    # save meta data
    meta_data = {
        'prompt' : prompt,
        'points' : points,
        'mask' : mask,
    }
    with open(os.path.join(save_dir, 'meta_data.pkl'), 'wb') as f:
        pickle.dump(meta_data, f)

    return save_prefix + " saved!"
with gr.Blocks() as demo:
    # UI components for editing real images
    with gr.Tab(label="Editing Real Image"):
        # State values shared between the callbacks below.
        mask = gr.State(value=None)  # store mask
        selected_points = gr.State([])  # store points
        original_image = gr.State(value=None)  # store original input image

        with gr.Row():
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
                # for mask painting
                canvas = gr.Image(
                    type="numpy",
                    tool="sketch",
                    label="Draw Mask",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                )
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
                # for points clicking
                input_image = gr.Image(
                    type="numpy",
                    label="Click Points",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                )

        with gr.Row():
            labeler = gr.Textbox(label="Labeler")
            category = gr.Dropdown(
                value="art_work",
                label="Image Category",
                choices=[
                    'art_work',
                    'land_scape',
                    'building_city_view',
                    'building_countryside_view',
                    'animals',
                    'human_head',
                    'human_upper_body',
                    'human_full_body',
                    'interior_design',
                    'other_objects',
                ],
            )
            prompt = gr.Textbox(label="Prompt")
            save_status = gr.Textbox(label="display saving status")

        with gr.Row():
            undo_button = gr.Button("undo points")
            clear_all_button = gr.Button("clear all")
            save_button = gr.Button("save")

    # event definition
    # editing the canvas stores the image/mask and resets the selected points
    canvas.edit(
        fn=store_img,
        inputs=[canvas],
        outputs=[original_image, selected_points, input_image, mask],
    )
    # clicking the image adds a point and redraws the points
    input_image.select(
        fn=get_points,
        inputs=[input_image, selected_points],
        outputs=[input_image],
    )
    # "undo points" restores the click-free preview and empties the point list
    undo_button.click(
        fn=undo_points,
        inputs=[original_image, mask],
        outputs=[input_image, selected_points],
    )
    # "clear all" resets both images and every state value; the hidden
    # number component passes LENGTH to the callback
    clear_all_button.click(
        fn=clear_all,
        inputs=[gr.Number(value=LENGTH, visible=False, precision=0)],
        outputs=[
            canvas,
            input_image,
            selected_points,
            original_image,
            mask,
        ],
    )
    # "save" writes the current sample to disk and reports the status
    save_button.click(
        fn=save_all,
        inputs=[
            category,
            original_image,
            input_image,
            mask,
            labeler,
            prompt,
            selected_points,
        ],
        outputs=[save_status],
    )

demo.queue().launch(share=True, debug=True)