|
import gradio as gr |
|
import torch |
|
import clip |
|
from PIL import Image |
|
import numpy as np |
|
|
|
# CLIP checkpoint used to score image/caption pairs.
_CLIP_BACKBONE = "RN50x64"

# Inference device. The model and its preprocessing transform are loaded once,
# at import time, and reused by every call to itm().
device = "cpu"
model, preprocess = clip.load(_CLIP_BACKBONE, device=device)
|
|
|
|
|
def img_process(img1, img2, location_width, location_height, size_width, size_height,
                threshold=240, background_size=(600, 400)):
    """Composite an object image onto a background image.

    The background is converted to RGBA and resized to ``background_size``.
    The object is resized to ``(size_width, size_height)`` and pasted with its
    top-left corner at ``(location_width, location_height)``.

    Transparency of the object comes from its own alpha channel when it has
    one (RGBA / LA / PA modes, or a palette image with a transparency entry).
    Otherwise every pixel whose R, G and B values all exceed ``threshold`` is
    treated as white backdrop and made fully transparent.

    Args:
        img1: Path (or file object) of the object image.
        img2: Path (or file object) of the background image.
        location_width: x offset, in pixels, of the object on the background.
        location_height: y offset, in pixels, of the object on the background.
        size_width: Width, in pixels, the object is resized to.
        size_height: Height, in pixels, the object is resized to.
        threshold: Channel value above which an opaque pixel counts as white
            backdrop. Only used for images without an alpha channel.
        background_size: ``(width, height)`` the background is resized to.

    Returns:
        The composited ``PIL.Image.Image`` in RGBA mode.

    Raises:
        ValueError: If the requested object size is not positive.
    """
    # UI widgets may deliver floats, but PIL needs integer sizes and paste boxes.
    width, height = int(size_width), int(size_height)
    if width <= 0 or height <= 0:
        raise ValueError(f"object size must be positive, got {width}x{height}")
    position = (int(location_width), int(location_height))

    # The context managers close the source files. convert() returns a new,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(img2) as background_src:
        background = background_src.convert('RGBA').resize(tuple(background_size))

    with Image.open(img1) as obj_src:
        has_alpha = (obj_src.mode in ('RGBA', 'LA', 'PA')
                     or 'transparency' in obj_src.info)
        obj = obj_src.convert('RGBA').resize((width, height))

    if not has_alpha:
        pixels = np.array(obj)
        # Pixels that are near-white in all three colour channels are the
        # object's backdrop; clearing their alpha hides them when pasting.
        backdrop = (pixels[..., :3] > threshold).all(axis=-1)
        pixels[backdrop, 3] = 0
        obj = Image.fromarray(pixels)  # uint8 (H, W, 4) is inferred as RGBA

    # The object's alpha channel doubles as the paste mask.
    background.paste(obj, position, mask=obj.getchannel('A'))
    return background
|
|
|
def itm(obj,back,location_width,location_height,size_width,size_height,is_obj,pos_obj,neg_obj,is_attr,pos_attr,neg_attr):
    """Score a positive and a negative caption for a composited image with CLIP.

    The object image ``obj`` is pasted onto ``back`` by ``img_process``. CLIP
    then compares the composite against two captions:

    * positive: ``"a photo of {pos_attr} {pos_obj}"``
    * negative: the same template, with ``neg_obj`` replacing the object when
      ``is_obj`` is true and ``neg_attr`` replacing the attribute when
      ``is_attr`` is true.

    The two image-text logits are softmaxed together, so the two returned
    probabilities sum to 1.

    Returns:
        ``(positive_prompt, positive_prob, negative_prompt, negative_prob,
        composited_image)``.
    """
    composite = img_process(obj, back, location_width, location_height, size_width, size_height)

    # Build each caption once and reuse it for tokenization and for the outputs.
    neg_obj_word = neg_obj if is_obj else pos_obj
    neg_attr_word = neg_attr if is_attr else pos_attr
    pos_prompt = f"a photo of {pos_attr} {pos_obj}"
    neg_prompt = f"a photo of {neg_attr_word} {neg_obj_word}"

    # Inputs must live on the same device the model was loaded onto.
    image_input = preprocess(composite).unsqueeze(0).to(device)
    text_input = clip.tokenize([pos_prompt, neg_prompt]).to(device)

    with torch.no_grad():
        logits_per_image, _ = model(image_input, text_input)
        # One image against two captions: softmax across the caption axis.
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    print("Label probs:", probs)

    return pos_prompt, probs[0][0], neg_prompt, probs[0][1], composite
|
|
|
|
|
with gr.Blocks() as demo:
    # Header and usage notes.
    gr.Markdown("<h1><center>VL-Checklist Demo</center></h1>")
    gr.Markdown("""

    Tips:

    - In this demo, you can change the object and attribute of object in the text prompt, and you can also change the size and location of the object.

    - Please upload an object image with white background.

    - The model we used in the demo is CLIP.

    """)

    # Inputs: object image and its placement (left), background and caption words (right).
    with gr.Row():
        with gr.Column():
            img_obj = gr.Image(
                value='sample/apple.png',
                type="filepath",
                label='object_img(Plz input an object with white background)',
            )
            loc_w = gr.Slider(maximum=500, label='location_width', step=1)
            loc_h = gr.Slider(maximum=300, label='location_height', step=1)
            s_w = gr.Number(value=200, precision=0, label='size_width')
            s_h = gr.Number(value=200, precision=0, label='size_height')
            gr.Markdown("Click **Submit** to get the output!")
        with gr.Column():
            img_back = gr.Image(value='sample/back.jpg', type="filepath", label='background_img')
            is_obj = gr.Checkbox(value=True, label='Does negative prompt change the object?')
            pos_obj = gr.Textbox(value='apple', label='positive object')
            neg_obj = gr.Textbox(value='dog', label='negative object')
            is_attr = gr.Checkbox(value=False, label='Does negative prompt change the attribute?')
            pos_attr = gr.Textbox(value='red', label='positive attribute')
            neg_attr = gr.Textbox(value='green', label='negative attribute')

    with gr.Row():
        btn = gr.Button("Submit", variant="primary")

    # Outputs: composited image (left), captions and their scores (right).
    with gr.Row():
        with gr.Column():
            img_output = gr.Image(type="pil", label='output_img')
        with gr.Column():
            pos_prom = gr.Textbox(label='Positive prompt')
            pos_s = gr.Textbox(label='Positive score')
            neg_prom = gr.Textbox(label='Negative prompt')
            neg_s = gr.Textbox(label='Negative score')

    # These orders must match itm's parameter list and return tuple.
    itm_inputs = [img_obj, img_back, loc_w, loc_h, s_w, s_h,
                  is_obj, pos_obj, neg_obj, is_attr, pos_attr, neg_attr]
    itm_outputs = [pos_prom, pos_s, neg_prom, neg_s, img_output]

    with gr.Row():
        gr.Examples(
            examples=[
                ['sample/apple.png', 'sample/back.jpg', 50, 50, 200, 200, True, 'apple', 'dog', False, 'red', 'green'],
                ['sample/banana.jpg', 'sample/back.jpg', 300, 200, 200, 200, True, 'bananas', 'peaches', False, 'yellow', 'green'],
            ],
            inputs=itm_inputs,
            outputs=itm_outputs,
            fn=itm,
            cache_examples=True,  # ask Gradio to cache the example outputs produced by itm
        )

    btn.click(fn=itm, inputs=itm_inputs, outputs=itm_outputs)

demo.launch()
|
|