ameerazam08 committed
Commit ffe8dd1 · verified · 1 Parent(s): 9a44348

Upload folder using huggingface_hub

Files changed (38)
  1. .gitattributes +1 -0
  2. README.md +50 -0
  3. app.py +341 -0
  4. data/ckpt/realisticVisionV60B1_v51VAE/feature_extractor/preprocessor_config.json +27 -0
  5. data/ckpt/realisticVisionV60B1_v51VAE/model_index.json +37 -0
  6. data/ckpt/realisticVisionV60B1_v51VAE/safety_checker/config.json +28 -0
  7. data/ckpt/realisticVisionV60B1_v51VAE/safety_checker/pytorch_model.bin +3 -0
  8. data/ckpt/realisticVisionV60B1_v51VAE/scheduler/scheduler_config.json +15 -0
  9. data/ckpt/realisticVisionV60B1_v51VAE/text_encoder/config.json +24 -0
  10. data/ckpt/realisticVisionV60B1_v51VAE/text_encoder/pytorch_model.bin +3 -0
  11. data/ckpt/realisticVisionV60B1_v51VAE/tokenizer/merges.txt +0 -0
  12. data/ckpt/realisticVisionV60B1_v51VAE/tokenizer/special_tokens_map.json +30 -0
  13. data/ckpt/realisticVisionV60B1_v51VAE/tokenizer/tokenizer_config.json +30 -0
  14. data/ckpt/realisticVisionV60B1_v51VAE/tokenizer/vocab.json +0 -0
  15. data/ckpt/realisticVisionV60B1_v51VAE/unet/config.json +67 -0
  16. data/ckpt/realisticVisionV60B1_v51VAE/unet/diffusion_pytorch_model.bin +3 -0
  17. data/ckpt/realisticVisionV60B1_v51VAE/vae/config.json +31 -0
  18. data/ckpt/realisticVisionV60B1_v51VAE/vae/diffusion_pytorch_model.bin +3 -0
  19. data/ckpt/sam_vit_h_4b8939.pth +3 -0
  20. data/ckpt/segmentation_mask_brushnet_ckpt/config.json +58 -0
  21. data/ckpt/segmentation_mask_brushnet_ckpt/diffusion_pytorch_model.safetensors +3 -0
  22. examples/brushnet/src/example_1.jpg +0 -0
  23. examples/brushnet/src/example_1_mask.jpg +0 -0
  24. examples/brushnet/src/example_1_result.png +0 -0
  25. examples/brushnet/src/example_3.jpg +0 -0
  26. examples/brushnet/src/example_3_mask.jpg +0 -0
  27. examples/brushnet/src/example_3_result.png +0 -0
  28. examples/brushnet/src/example_4.jpeg +0 -0
  29. examples/brushnet/src/example_4_mask.jpg +0 -0
  30. examples/brushnet/src/example_4_result.png +0 -0
  31. examples/brushnet/src/example_5.jpg +0 -0
  32. examples/brushnet/src/example_5_mask.jpg +0 -0
  33. examples/brushnet/src/example_5_result.png +0 -0
  34. examples/brushnet/src/test_image.jpg +0 -0
  35. examples/brushnet/src/test_mask.jpg +0 -0
  36. examples/brushnet/src/test_result.png +0 -0
  37. mask.png +0 -0
  38. requirements.txt +19 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ file filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,50 @@
+ ---
+ title: BrushNet
+ emoji: ⚡
+ colorFrom: yellow
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.50.2
+ python_version: 3.9
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
+ # BrushNet
+
+ This repository contains the Gradio demo for the paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion".
+
+ Keywords: Image Inpainting, Diffusion Models, Image Generation
+
+ > [Xuan Ju](https://github.com/juxuan27)<sup>12</sup>, [Xian Liu](https://alvinliu0.github.io/)<sup>12</sup>, [Xintao Wang](https://xinntao.github.io/)<sup>1*</sup>, [Yuxuan Bian](https://scholar.google.com.hk/citations?user=HzemVzoAAAAJ&hl=zh-CN&oi=ao)<sup>2</sup>, [Ying Shan](https://www.linkedin.com/in/YingShanProfile/)<sup>1</sup>, [Qiang Xu](https://cure-lab.github.io/)<sup>2*</sup><br>
+ > <sup>1</sup>ARC Lab, Tencent PCG <sup>2</sup>The Chinese University of Hong Kong <sup>*</sup>Corresponding Author
+
+
+ <p align="center">
+   <a href="https://tencentarc.github.io/BrushNet/">Project Page</a> |
+   <a href="https://github.com/TencentARC/BrushNet">Code</a> |
+   <a href="https://arxiv.org/abs/2403.06976">arXiv</a> |
+   <a href="https://forms.gle/9TgMZ8tm49UYsZ9s5">Data</a> |
+   <a href="https://drive.google.com/file/d/1IkEBWcd2Fui2WHcckap4QFPcCI0gkHBh/view">Video</a>
+ </p>
+
+
+ ## 🤝🏼 Cite Us
+
+ ```
+ @misc{ju2024brushnet,
+       title={BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion},
+       author={Xuan Ju and Xian Liu and Xintao Wang and Yuxuan Bian and Ying Shan and Qiang Xu},
+       year={2024},
+       eprint={2403.06976},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV}
+ }
+ ```
+
+
+ ## 💖 Acknowledgement
+ <span id="acknowledgement"></span>
+
+ Our code is modified from [diffusers](https://github.com/huggingface/diffusers); thanks to all the contributors!
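
For quick reference, here is a minimal sketch of driving the checkpoints shipped in this Space outside the Gradio UI. It mirrors the pipeline setup in app.py below; the assumptions are that the BrushNet fork of diffusers from requirements.txt is installed, a CUDA GPU is available, and that white pixels in the bundled example mask mark the region to repaint (the app handles mask polarity interactively, so that last point is an assumption).

```python
# Minimal sketch, assuming the TencentARC/BrushNet diffusers fork and a CUDA GPU.
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler

brushnet = BrushNetModel.from_pretrained(
    "data/ckpt/segmentation_mask_brushnet_ckpt", torch_dtype=torch.float16
)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", brushnet=brushnet,
    torch_dtype=torch.float16, low_cpu_mem_usage=False, safety_checker=None
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

image_np = np.array(Image.open("examples/brushnet/src/test_image.jpg").convert("RGB"))
mask_np = np.array(Image.open("examples/brushnet/src/test_mask.jpg").convert("RGB"))

# As in app.py: binarize the mask and black out the region that will be repainted
# (assumes white = region to repaint; flip the mask if the example uses the opposite convention).
mask = 1.0 * (mask_np.sum(-1) > 255)[:, :, np.newaxis]
init_image = Image.fromarray((image_np * (1 - mask)).astype(np.uint8))
mask_image = Image.fromarray(mask_np.astype(np.uint8))

result = pipe(
    "A beautiful cake on the table", init_image, mask_image,
    num_inference_steps=50, guidance_scale=12,
    brushnet_conditioning_scale=1.0,
).images[0]
result.save("test_result_local.png")
```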
app.py ADDED
@@ -0,0 +1,341 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import os
+
+ # print("Installing correct gradio version...")
+ # os.system("pip uninstall -y gradio")
+ # os.system("pip install gradio==3.50.0")
+ # print("Installing Finished!")
+
+ ##!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import gradio as gr
+ import os
+ import cv2
+ from PIL import Image
+ import numpy as np
+ from segment_anything import SamPredictor, sam_model_registry
+ import torch
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
+ import random
+
+ mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to("cuda")
+ mobile_sam.eval()
+ mobile_predictor = SamPredictor(mobile_sam)
+ colors = [(255, 0, 0), (0, 255, 0)]
+ markers = [1, 5]
+
+ # - - - - - examples - - - - - #
+ image_examples = [
+     ["examples/brushnet/src/test_image.jpg", "A beautiful cake on the table", "examples/brushnet/src/test_mask.jpg", 0, [], [Image.open("examples/brushnet/src/test_result.png")]],
+     ["examples/brushnet/src/example_1.jpg", "A man in Chinese traditional clothes", "examples/brushnet/src/example_1_mask.jpg", 1, [], [Image.open("examples/brushnet/src/example_1_result.png")]],
+     ["examples/brushnet/src/example_3.jpg", "a cut toy on the table", "examples/brushnet/src/example_3_mask.jpg", 2, [], [Image.open("examples/brushnet/src/example_3_result.png")]],
+     ["examples/brushnet/src/example_4.jpeg", "a car driving in the wild", "examples/brushnet/src/example_4_mask.jpg", 3, [], [Image.open("examples/brushnet/src/example_4_result.png")]],
+     ["examples/brushnet/src/example_5.jpg", "a charming woman wearing dress standing in the dark forest", "examples/brushnet/src/example_5_mask.jpg", 4, [], [Image.open("examples/brushnet/src/example_5_result.png")]],
+ ]
+
+
+ # choose the base model here
+ # base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"
+ base_model_path = "runwayml/stable-diffusion-v1-5"
+
+ # input brushnet ckpt path
+ brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"
+
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16, safety_checker=None)
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
+     base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False, safety_checker=None
+ )
+
+ # speed up the diffusion process with a faster scheduler and memory optimization
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ # remove the following line if xformers is not installed or when using Torch 2.0.
+ # pipe.enable_xformers_memory_efficient_attention()
+ # memory optimization.
+ pipe.enable_model_cpu_offload()
+
+ def resize_image(input_image, resolution):
+     H, W, C = input_image.shape
+     H = float(H)
+     W = float(W)
+     k = float(resolution) / min(H, W)
+     H *= k
+     W *= k
+     H = int(np.round(H / 64.0)) * 64
+     W = int(np.round(W / 64.0)) * 64
+     img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+     return img
+
+
+ def process(input_image,
+             original_image,
+             original_mask,
+             input_mask,
+             selected_points,
+             prompt,
+             negative_prompt,
+             blended,
+             invert_mask,
+             control_strength,
+             seed,
+             randomize_seed,
+             guidance_scale,
+             num_inference_steps):
+     if original_image is None:
+         raise gr.Error('Please upload the input image')
+     if (original_mask is None or len(selected_points) == 0) and input_mask is None:
+         raise gr.Error("Please click the region where you hope unchanged/changed, or upload a white-black Mask image")
+
+     # load example image
+     if isinstance(original_image, int):
+         image_name = image_examples[original_image][0]
+         original_image = cv2.imread(image_name)
+         original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
+
+     if input_mask is not None:
+         H, W = original_image.shape[:2]
+         original_mask = cv2.resize(input_mask, (W, H))
+     else:
+         original_mask = np.clip(255 - original_mask, 0, 255).astype(np.uint8)
+
+     if invert_mask:
+         original_mask = 255 - original_mask
+
+     mask = 1. * (original_mask.sum(-1) > 255)[:, :, np.newaxis]
+     masked_image = original_image * (1 - mask)
+
+     init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
+     mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB")
+     mask_image.save("./mask.png")
+
+     generator = torch.Generator("cuda").manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
+     image_num = 3
+     image = pipe(
+         [prompt] * image_num,
+         init_image,
+         mask_image,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         generator=generator,
+         brushnet_conditioning_scale=float(control_strength),
+         negative_prompt=[negative_prompt] * image_num,
+     ).images
+
+     if blended:
+         if control_strength < 1.0:
+             raise gr.Error('Using blurred blending with control strength less than 1.0 is not allowed')
+         blended_image = []
+         # blur; you can adjust the parameters for better performance
+         mask_blurred = cv2.GaussianBlur(mask * 255, (21, 21), 0) / 255
+         mask_blurred = mask_blurred[:, :, np.newaxis]
+         mask = 1 - (1 - mask) * (1 - mask_blurred)
+         for image_i in image:
+             image_np = np.array(image_i)
+             image_pasted = original_image * (1 - mask) + image_np * mask
+
+             image_pasted = image_pasted.astype(image_np.dtype)
+             blended_image.append(Image.fromarray(image_pasted))
+
+         image = blended_image
+
+     return image
+
+ block = gr.Blocks(
+     theme=gr.themes.Soft(
+         radius_size=gr.themes.sizes.radius_none,
+         text_size=gr.themes.sizes.text_md
+     )
+ ).queue()
+ with block:
+     with gr.Row():
+         with gr.Column():
+
+             gr.HTML(f"""
+                 <div style="text-align: center;">
+                     <h1>BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion</h1>
+                     <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+                         <a href=""></a>
+                         <a href='https://tencentarc.github.io/BrushNet/'><img src='https://img.shields.io/badge/Project_Page-BrushNet-green' alt='Project Page'></a>
+                         <a href='https://arxiv.org/abs/2403.06976'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
+                     </div>
+                     </br>
+                 </div>
+             """)
+
+
+             with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
+                 with gr.Row(equal_height=True):
+                     gr.Markdown("""
+                     - ⭐️ <b>step1: </b>Upload or select one image from Example
+                     - ⭐️ <b>step2: </b>Click on Input-image to select the object to be retained (or upload a white-black Mask image, in which white color indicates the region you want to keep unchanged). You can tick the 'Invert Mask' box to switch region unchanged and change.
+                     - ⭐️ <b>step3: </b>Input prompt for generating new contents
+                     - ⭐️ <b>step4: </b>Click Run button
+                     """)
+     with gr.Row():
+         with gr.Column():
+             with gr.Column(elem_id="Input"):
+                 with gr.Row():
+                     with gr.Tabs(elem_classes=["feedback"]):
+                         with gr.TabItem("Input Image"):
+                             input_image = gr.Image(type="numpy", label="input", scale=2, height=640)
+                 original_image = gr.State(value=None, label="index")
+                 original_mask = gr.State(value=None)
+                 selected_points = gr.State([], label="select points")
+                 with gr.Row(elem_id="Seg"):
+                     radio = gr.Radio(['foreground', 'background'], label='Click to seg: ', value='foreground', scale=2)
+                     undo_button = gr.Button('Undo seg', elem_id="btnSEG", scale=1)
+                 prompt = gr.Textbox(label="Prompt", placeholder="Please input your prompt", value='', lines=1)
+                 negative_prompt = gr.Text(
+                     label="Negative Prompt",
+                     max_lines=5,
+                     placeholder="Please input your negative prompt",
+                     value='ugly, low quality', lines=1
+                 )
+                 with gr.Group():
+                     with gr.Row():
+                         blending = gr.Checkbox(label="Blurred Blending", value=False)
+                         invert_mask = gr.Checkbox(label="Invert Mask", value=True)
+                 run_button = gr.Button("Run", elem_id="btn")
+
+                 with gr.Accordion("More input params (highly-recommended)", open=False, elem_id="accordion1"):
+                     control_strength = gr.Slider(
+                         label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
+                     )
+                     with gr.Group():
+                         seed = gr.Slider(
+                             label="Seed: ", minimum=0, maximum=2147483647, step=1, value=551793204
+                         )
+                         randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+
+                     with gr.Group():
+                         with gr.Row():
+                             guidance_scale = gr.Slider(
+                                 label="Guidance scale",
+                                 minimum=1,
+                                 maximum=12,
+                                 step=0.1,
+                                 value=12,
+                             )
+                             num_inference_steps = gr.Slider(
+                                 label="Number of inference steps",
+                                 minimum=1,
+                                 maximum=50,
+                                 step=1,
+                                 value=50,
+                             )
+                 with gr.Row(elem_id="Image"):
+                     with gr.Tabs(elem_classes=["feedback1"]):
+                         with gr.TabItem("User-specified Mask Image (Optional)"):
+                             input_mask = gr.Image(type="numpy", label="Mask Image", height=640)
+
+         with gr.Column():
+             with gr.Tabs(elem_classes=["feedback"]):
+                 with gr.TabItem("Outputs"):
+                     result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True)
+     with gr.Row():
+         def process_example(input_image, prompt, input_mask, original_image, selected_points, result_gallery):
+             return input_image, prompt, input_mask, original_image, [], result_gallery
+         example = gr.Examples(
+             label="Input Example",
+             examples=image_examples,
+             inputs=[input_image, prompt, input_mask, original_image, selected_points, result_gallery],
+             outputs=[input_image, prompt, input_mask, original_image, selected_points],
+             fn=process_example,
+             run_on_click=True,
+             examples_per_page=10
+         )
+
+     # once the user uploads an image, the original image is stored in `original_image`
+     def store_img(img):
+         # image upload is too slow
+         if min(img.shape[0], img.shape[1]) > 512:
+             img = resize_image(img, 512)
+         if max(img.shape[0], img.shape[1]) * 1.0 / min(img.shape[0], img.shape[1]) > 2.0:
+             raise gr.Error('image aspect ratio cannot be larger than 2.0')
+         return img, img, [], None  # when a new image is uploaded, `selected_points` should be empty
+
+     input_image.upload(
+         store_img,
+         [input_image],
+         [input_image, original_image, selected_points]
+     )
+
+     # the user clicks the image to get points, and the points are shown on the image
+     def segmentation(img, sel_pix):
+         # online show seg mask
+         points = []
+         labels = []
+         for p, l in sel_pix:
+             points.append(p)
+             labels.append(l)
+         mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
+         with torch.no_grad():
+             masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)
+
+         output_mask = np.ones((masks.shape[1], masks.shape[2], 3)) * 255
+         for i in range(3):
+             output_mask[masks[0] == True, i] = 0.0
+
+         mask_all = np.ones((masks.shape[1], masks.shape[2], 3))
+         color_mask = np.random.random((1, 3)).tolist()[0]
+         for i in range(3):
+             mask_all[masks[0] == True, i] = color_mask[i]
+         masked_img = img / 255 * 0.3 + mask_all * 0.7
+         masked_img = masked_img * 255
+         ## draw points
+         for point, label in sel_pix:
+             cv2.drawMarker(masked_img, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
+         return masked_img, output_mask
+
+     def get_point(img, sel_pix, point_type, evt: gr.SelectData):
+         if point_type == 'foreground':
+             sel_pix.append((evt.index, 1))  # append the foreground_point
+         elif point_type == 'background':
+             sel_pix.append((evt.index, 0))  # append the background_point
+         else:
+             sel_pix.append((evt.index, 1))  # default foreground_point
+
+         if isinstance(img, int):
+             image_name = image_examples[img][0]
+             img = cv2.imread(image_name)
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+         # online show seg mask
+         masked_img, output_mask = segmentation(img, sel_pix)
+         return masked_img.astype(np.uint8), output_mask
+
+     input_image.select(
+         get_point,
+         [original_image, selected_points, radio],
+         [input_image, original_mask],
+     )
+
+     # undo the selected point
+     def undo_points(orig_img, sel_pix):
+         # draw points
+         output_mask = None
+         if len(sel_pix) != 0:
+             if isinstance(orig_img, int):  # if orig_img is an int, the image is selected from the examples
+                 temp = cv2.imread(image_examples[orig_img][0])
+                 temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
+             else:
+                 temp = orig_img.copy()
+             sel_pix.pop()
+             # online show seg mask
+             if len(sel_pix) != 0:
+                 temp, output_mask = segmentation(temp, sel_pix)
+             return temp.astype(np.uint8), output_mask
+         else:
+             gr.Error("Nothing to Undo")
+
+     undo_button.click(
+         undo_points,
+         [original_image, selected_points],
+         [input_image, original_mask]
+     )
+
+     ips = [input_image, original_image, original_mask, input_mask, selected_points, prompt, negative_prompt, blending, invert_mask, control_strength, seed, randomize_seed, guidance_scale, num_inference_steps]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+ block.launch()
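
The "Blurred Blending" branch inside process() is the one piece of post-processing that is easy to miss; below is a standalone restatement of that step as a sketch. The 21×21 Gaussian kernel and the mask layout are copied from the code above; the helper name and signature are illustrative only.

```python
import cv2
import numpy as np
from PIL import Image

def paste_back(original_rgb: np.ndarray, generated: Image.Image, mask: np.ndarray) -> Image.Image:
    """Feathered paste-back, as in the `blended` branch of process() above.

    `mask` is HxWx1 with 1.0 inside the repainted region (the layout process() builds);
    `generated` must have the same HxW as `original_rgb`.
    """
    # Soften the mask edge so the seam between generated and original pixels is hidden.
    mask_blurred = cv2.GaussianBlur(mask * 255, (21, 21), 0) / 255
    mask_blurred = mask_blurred[:, :, np.newaxis]
    soft_mask = 1 - (1 - mask) * (1 - mask_blurred)  # keep the hard mask, feather only outward
    generated_np = np.array(generated)
    pasted = original_rgb * (1 - soft_mask) + generated_np * soft_mask
    return Image.fromarray(pasted.astype(np.uint8))
```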
data/ckpt/realisticVisionV60B1_v51VAE/feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "crop_size": {
+     "height": 224,
+     "width": 224
+   },
+   "do_center_crop": true,
+   "do_convert_rgb": true,
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "image_mean": [
+     0.48145466,
+     0.4578275,
+     0.40821073
+   ],
+   "image_processor_type": "CLIPFeatureExtractor",
+   "image_std": [
+     0.26862954,
+     0.26130258,
+     0.27577711
+   ],
+   "resample": 3,
+   "rescale_factor": 0.00392156862745098,
+   "size": {
+     "shortest_edge": 224
+   }
+ }
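
This is a stock CLIP image-processor config, so it can be loaded on its own as a sanity check. A sketch follows; the path is just this repo's folder layout, and `CLIPImageProcessor` is the current transformers name for the `CLIPFeatureExtractor` class named in the JSON.

```python
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor.from_pretrained(
    "data/ckpt/realisticVisionV60B1_v51VAE/feature_extractor"
)
print(processor.crop_size)  # {'height': 224, 'width': 224}, matching the config above
```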
data/ckpt/realisticVisionV60B1_v51VAE/model_index.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "_class_name": "StableDiffusionPipeline",
+   "_diffusers_version": "0.26.0.dev0",
+   "feature_extractor": [
+     "transformers",
+     "CLIPFeatureExtractor"
+   ],
+   "image_encoder": [
+     null,
+     null
+   ],
+   "requires_safety_checker": true,
+   "safety_checker": [
+     "stable_diffusion",
+     "StableDiffusionSafetyChecker"
+   ],
+   "scheduler": [
+     "diffusers",
+     "PNDMScheduler"
+   ],
+   "text_encoder": [
+     "transformers",
+     "CLIPTextModel"
+   ],
+   "tokenizer": [
+     "transformers",
+     "CLIPTokenizer"
+   ],
+   "unet": [
+     "diffusers",
+     "UNet2DConditionModel"
+   ],
+   "vae": [
+     "diffusers",
+     "AutoencoderKL"
+   ]
+ }
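
model_index.json is the file diffusers reads to decide which library and class load each sub-folder, which is why the commented-out `base_model_path` in app.py can point at this directory. A sketch of loading it directly (assuming the directory is complete, as listed in this commit):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "data/ckpt/realisticVisionV60B1_v51VAE", torch_dtype=torch.float16
)
print(type(pipe.scheduler).__name__)  # PNDMScheduler, as declared above
```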
data/ckpt/realisticVisionV60B1_v51VAE/safety_checker/config.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "_name_or_path": "CompVis/stable-diffusion-safety-checker",
+   "architectures": [
+     "StableDiffusionSafetyChecker"
+   ],
+   "initializer_factor": 1.0,
+   "logit_scale_init_value": 2.6592,
+   "model_type": "clip",
+   "projection_dim": 768,
+   "text_config": {
+     "dropout": 0.0,
+     "hidden_size": 768,
+     "intermediate_size": 3072,
+     "model_type": "clip_text_model",
+     "num_attention_heads": 12
+   },
+   "torch_dtype": "float32",
+   "transformers_version": "4.37.2",
+   "vision_config": {
+     "dropout": 0.0,
+     "hidden_size": 1024,
+     "intermediate_size": 4096,
+     "model_type": "clip_vision_model",
+     "num_attention_heads": 16,
+     "num_hidden_layers": 24,
+     "patch_size": 14
+   }
+ }
data/ckpt/realisticVisionV60B1_v51VAE/safety_checker/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:afa7ebf10b23008ecbb81111e9cc8818443e1302a0defaa0a6cbd8cfe310b278
+ size 1216059369
data/ckpt/realisticVisionV60B1_v51VAE/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "_class_name": "PNDMScheduler",
+   "_diffusers_version": "0.26.0.dev0",
+   "beta_end": 0.012,
+   "beta_schedule": "scaled_linear",
+   "beta_start": 0.00085,
+   "clip_sample": false,
+   "num_train_timesteps": 1000,
+   "prediction_type": "epsilon",
+   "set_alpha_to_one": false,
+   "skip_prk_steps": true,
+   "steps_offset": 1,
+   "timestep_spacing": "leading",
+   "trained_betas": null
+ }
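
app.py swaps this saved PNDMScheduler for UniPCMultistepScheduler; because diffusers schedulers share the config format above, the swap reuses the same beta schedule and timestep settings. A sketch, loading the scheduler folder on its own:

```python
from diffusers import PNDMScheduler, UniPCMultistepScheduler

pndm = PNDMScheduler.from_pretrained(
    "data/ckpt/realisticVisionV60B1_v51VAE/scheduler"
)
# Same beta schedule / timestep settings, different (faster) solver:
uni_pc = UniPCMultistepScheduler.from_config(pndm.config)
print(uni_pc.config.beta_schedule)  # scaled_linear
```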
data/ckpt/realisticVisionV60B1_v51VAE/text_encoder/config.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "architectures": [
+     "CLIPTextModel"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "dropout": 0.0,
+   "eos_token_id": 2,
+   "hidden_act": "quick_gelu",
+   "hidden_size": 768,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 77,
+   "model_type": "clip_text_model",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 1,
+   "projection_dim": 768,
+   "torch_dtype": "float32",
+   "transformers_version": "4.37.2",
+   "vocab_size": 49408
+ }
data/ckpt/realisticVisionV60B1_v51VAE/text_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ccf6d26500a883b3e7227d2f10292bee7fe078f26cb1afef46eee26252805f0b
+ size 492304313
data/ckpt/realisticVisionV60B1_v51VAE/tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/ckpt/realisticVisionV60B1_v51VAE/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "bos_token": {
+     "content": "<|startoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
data/ckpt/realisticVisionV60B1_v51VAE/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "49406": {
+       "content": "<|startoftext|>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "49407": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": "<|startoftext|>",
+   "clean_up_tokenization_spaces": true,
+   "do_lower_case": true,
+   "eos_token": "<|endoftext|>",
+   "errors": "replace",
+   "model_max_length": 77,
+   "pad_token": "<|endoftext|>",
+   "tokenizer_class": "CLIPTokenizer",
+   "unk_token": "<|endoftext|>"
+ }
data/ckpt/realisticVisionV60B1_v51VAE/tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
data/ckpt/realisticVisionV60B1_v51VAE/unet/config.json ADDED
@@ -0,0 +1,67 @@
+ {
+   "_class_name": "UNet2DConditionModel",
+   "_diffusers_version": "0.26.0.dev0",
+   "act_fn": "silu",
+   "addition_embed_type": null,
+   "addition_embed_type_num_heads": 64,
+   "addition_time_embed_dim": null,
+   "attention_head_dim": 8,
+   "attention_type": "default",
+   "block_out_channels": [
+     320,
+     640,
+     1280,
+     1280
+   ],
+   "center_input_sample": false,
+   "class_embed_type": null,
+   "class_embeddings_concat": false,
+   "conv_in_kernel": 3,
+   "conv_out_kernel": 3,
+   "cross_attention_dim": 768,
+   "cross_attention_norm": null,
+   "down_block_types": [
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "DownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "dropout": 0.0,
+   "dual_cross_attention": false,
+   "encoder_hid_dim": null,
+   "encoder_hid_dim_type": null,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "in_channels": 4,
+   "layers_per_block": 2,
+   "mid_block_only_cross_attention": null,
+   "mid_block_scale_factor": 1,
+   "mid_block_type": "UNetMidBlock2DCrossAttn",
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_attention_heads": null,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "out_channels": 4,
+   "projection_class_embeddings_input_dim": null,
+   "resnet_out_scale_factor": 1.0,
+   "resnet_skip_time_act": false,
+   "resnet_time_scale_shift": "default",
+   "reverse_transformer_layers_per_block": null,
+   "sample_size": 64,
+   "time_cond_proj_dim": null,
+   "time_embedding_act_fn": null,
+   "time_embedding_dim": null,
+   "time_embedding_type": "positional",
+   "timestep_post_act": null,
+   "transformer_layers_per_block": 1,
+   "up_block_types": [
+     "UpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D"
+   ],
+   "upcast_attention": false,
+   "use_linear_projection": false
+ }
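
The unet sub-folder can likewise be loaded in isolation; the numbers printed below follow directly from the config above (a sketch, with the path assumed from this repo's layout):

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "data/ckpt/realisticVisionV60B1_v51VAE/unet", torch_dtype=torch.float16
)
print(unet.config.in_channels, unet.config.cross_attention_dim)  # 4 768
print(unet.config.sample_size * 8)  # 512 -- native pixel resolution (8x VAE downsampling)
```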
data/ckpt/realisticVisionV60B1_v51VAE/unet/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c8dac118ff5749ca65298f1546ec23d3fec876cded41f9559fcd863e30d969f4
+ size 3438354725
data/ckpt/realisticVisionV60B1_v51VAE/vae/config.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.26.0.dev0",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": true,
+   "in_channels": 3,
+   "latent_channels": 4,
+   "layers_per_block": 2,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 512,
+   "scaling_factor": 0.18215,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ]
+ }
data/ckpt/realisticVisionV60B1_v51VAE/vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9428a7d3cab924439bbba9f2c3e5128d466f9dffd9c12ed7098ae288a5bfe8bd
+ size 334707473
data/ckpt/sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+ size 2564550879
data/ckpt/segmentation_mask_brushnet_ckpt/config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "_class_name": "BrushNetModel",
+   "_diffusers_version": "0.27.0.dev0",
+   "_name_or_path": "runs/logs/brushnet_segmask/checkpoint-550000",
+   "act_fn": "silu",
+   "addition_embed_type": null,
+   "addition_embed_type_num_heads": 64,
+   "addition_time_embed_dim": null,
+   "attention_head_dim": 8,
+   "block_out_channels": [
+     320,
+     640,
+     1280,
+     1280
+   ],
+   "brushnet_conditioning_channel_order": "rgb",
+   "class_embed_type": null,
+   "conditioning_channels": 5,
+   "conditioning_embedding_out_channels": [
+     16,
+     32,
+     96,
+     256
+   ],
+   "cross_attention_dim": 768,
+   "down_block_types": [
+     "DownBlock2D",
+     "DownBlock2D",
+     "DownBlock2D",
+     "DownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "encoder_hid_dim": null,
+   "encoder_hid_dim_type": null,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "global_pool_conditions": false,
+   "in_channels": 4,
+   "layers_per_block": 2,
+   "mid_block_scale_factor": 1,
+   "mid_block_type": "MidBlock2D",
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_attention_heads": null,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "projection_class_embeddings_input_dim": null,
+   "resnet_time_scale_shift": "default",
+   "transformer_layers_per_block": 1,
+   "up_block_types": [
+     "UpBlock2D",
+     "UpBlock2D",
+     "UpBlock2D",
+     "UpBlock2D"
+   ],
+   "upcast_attention": false,
+   "use_linear_projection": false
+ }
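
Note `conditioning_channels: 5` and the attention-free `DownBlock2D`/`MidBlock2D`/`UpBlock2D` stack: per the BrushNet paper, the extra branch conditions on the masked-image latent (4 channels) plus a downsampled mask (1 channel) and uses no cross-attention of its own. A sketch of loading the checkpoint by itself; `BrushNetModel` comes from the BrushNet fork of diffusers installed via requirements.txt.

```python
import torch
from diffusers import BrushNetModel  # available in the TencentARC/BrushNet fork of diffusers

brushnet = BrushNetModel.from_pretrained(
    "data/ckpt/segmentation_mask_brushnet_ckpt", torch_dtype=torch.float16
)
print(brushnet.config.conditioning_channels)  # 5
```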
data/ckpt/segmentation_mask_brushnet_ckpt/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:066793867083310c25b0d57d300b1ecdc02cb708a155060f020be64afb2ae9c3
3
+ size 2475354520
examples/brushnet/src/example_1.jpg ADDED
examples/brushnet/src/example_1_mask.jpg ADDED
examples/brushnet/src/example_1_result.png ADDED
examples/brushnet/src/example_3.jpg ADDED
examples/brushnet/src/example_3_mask.jpg ADDED
examples/brushnet/src/example_3_result.png ADDED
examples/brushnet/src/example_4.jpeg ADDED
examples/brushnet/src/example_4_mask.jpg ADDED
examples/brushnet/src/example_4_result.png ADDED
examples/brushnet/src/example_5.jpg ADDED
examples/brushnet/src/example_5_mask.jpg ADDED
examples/brushnet/src/example_5_result.png ADDED
examples/brushnet/src/test_image.jpg ADDED
examples/brushnet/src/test_mask.jpg ADDED
examples/brushnet/src/test_result.png ADDED
mask.png ADDED
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ torch
+ torchvision
+ torchaudio
+ transformers>=4.25.1
+ gradio==3.50.0
+ ftfy
+ tensorboard
+ datasets
+ Pillow==9.5.0
+ opencv-python
+ imgaug
+ accelerate==0.20.3
+ image-reward
+ hpsv2
+ torchmetrics
+ open-clip-torch
+ clip
+ segment_anything
+ git+https://github.com/TencentARC/BrushNet.git
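
The last line is the important one: it installs the TencentARC/BrushNet repository, which provides the modified diffusers that app.py imports. A quick sanity check (a sketch) that the fork, rather than stock diffusers, is on the path:

```python
# These classes are used by app.py and are not part of upstream diffusers, so a clean
# import is a reasonable check that the fork from requirements.txt is installed.
import diffusers
from diffusers import BrushNetModel, StableDiffusionBrushNetPipeline

print("diffusers (BrushNet fork) version:", diffusers.__version__)
```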