Delete test.py
test.py
DELETED
@@ -1,376 +0,0 @@
import argparse, os
import cv2
import torch
import numpy as np
import torchvision
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.modules.diffusionmodules.openaimodel import clear_feature_dic, get_feature_dic
from ldm.models.seg_module import Segmodule

os.environ["CUDA_VISIBLE_DEVICES"] = "1"


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model


def put_watermark(img, wm_encoder=None):
    if wm_encoder is not None:
        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        img = wm_encoder.encode(img, 'dwtDct')
        img = Image.fromarray(img[:, :, ::-1])
    return img


def load_replacement(x):
    try:
        hwc = x.shape
        y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
        y = (np.array(y) / 255.0).astype(x.dtype)
        assert y.shape == x.shape
        return y
    except Exception:
        return x


def plot_mask(img, masks, colors=None, alpha=0.8, indexlist=[0, 1]) -> np.ndarray:
    """Visualize segmentation masks.

    Parameters
    ----------
    img: numpy.ndarray
        Image with shape `(H, W, 3)`.
    masks: numpy.ndarray
        Binary masks with shape `(N, H, W)`.
    colors: numpy.ndarray
        Color for each mask, shape `(N, 3)`.
        If None, colors are taken from a fixed palette.
    alpha: float, optional, default 0.8
        Transparency of the plotted masks.

    Returns
    -------
    numpy.ndarray
        The image overlaid with the segmentation masks, shape `(H, W, 3)`.
    """
    color_list = [[255, 97, 0], [128, 42, 42], [220, 220, 220], [255, 153, 18], [56, 94, 15],
                  [127, 255, 212], [210, 180, 140], [221, 160, 221], [255, 0, 0], [255, 128, 0],
                  [255, 255, 0], [128, 255, 0], [0, 255, 0], [0, 255, 128], [0, 255, 255],
                  [0, 128, 255], [0, 0, 255], [128, 0, 255], [255, 0, 255], [255, 0, 128]] * 6
    final_color_list = [np.array([[i] * 512] * 512) for i in color_list]

    background = np.ones(img.shape) * 255
    count = 0
    colors = final_color_list[indexlist[count]]
    for mask, color in zip(masks, colors):
        color = final_color_list[indexlist[count]]
        mask = np.stack([mask, mask, mask], -1)
        img = np.where(mask, img * (1 - alpha) + color * alpha, background * 0.4 + img * 0.6)
        count += 1
    return img.astype(np.uint8)


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a photo of a lion on a mountain top at sunset",
        help="the prompt to render"
    )
    parser.add_argument(
        "--category",
        type=str,
        nargs="?",
        default="lion",
        help="the category to ground"
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--skip_save",
        action='store_true',
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--laion400m",
        action='store_true',
        help="uses the LAION400M model",
    )
    parser.add_argument(
        "--fixed_code",
        action='store_true',
        help="if enabled, uses the same starting code across samples",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=512,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=512,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for each given prompt. A.k.a. batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=7.5,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v1-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--sd_ckpt",
        type=str,
        default="stable_diffusion.ckpt",
        help="path to checkpoint of stable diffusion model",
    )
    parser.add_argument(
        "--grounding_ckpt",
        type=str,
        default="grounding_module.pth",
        help="path to checkpoint of grounding module",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast"
    )
    opt = parser.parse_args()

    if opt.laion400m:
        print("Falling back to LAION 400M model...")
        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
        opt.outdir = "outputs/txt2img-samples-laion400m"

    seed_everything(opt.seed)

    tic = time.time()
    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.sd_ckpt}")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    toc = time.time()
    seg_module = Segmodule().to(device)

    seg_module.load_state_dict(torch.load(opt.grounding_ckpt, map_location="cpu"), strict=True)
    print('load time:', toc - tic)
    sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir
    batch_size = opt.n_samples
    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                prompt = opt.prompt
                text = opt.category
                trainclass = text
                if not opt.from_file:
                    assert prompt is not None
                    data = [batch_size * [prompt]]
                else:
                    print(f"reading prompts from {opt.from_file}")
                    with open(opt.from_file, "r") as f:
                        data = f.read().splitlines()
                        data = list(chunk(data, batch_size))

                sample_path = os.path.join(outpath, "samples")
                os.makedirs(sample_path, exist_ok=True)

                start_code = None
                if opt.fixed_code:
                    print('start_code')
                    start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
                for n in trange(opt.n_iter, desc="Sampling"):
                    for prompts in tqdm(data, desc="data"):
                        clear_feature_dic()
                        uc = None
                        if opt.scale != 1.0:
                            uc = model.get_learned_conditioning(batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)

                        c = model.get_learned_conditioning(prompts)
                        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                        samples_ddim, _, _ = sampler.sample(S=opt.ddim_steps,
                                                            conditioning=c,
                                                            batch_size=opt.n_samples,
                                                            shape=shape,
                                                            verbose=False,
                                                            unconditional_guidance_scale=opt.scale,
                                                            unconditional_conditioning=uc,
                                                            eta=opt.ddim_eta,
                                                            x_T=start_code)

                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        diffusion_features = get_feature_dic()

                        x_sample = torch.clamp((x_samples_ddim[0] + 1.0) / 2.0, min=0.0, max=1.0)
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')

                        Image.fromarray(x_sample.astype(np.uint8)).save("demo/demo.png")
                        img = x_sample.astype(np.uint8)

                        class_name = trainclass

                        query_text = "a " + prompt.split()[1] + " of a " + class_name
                        c_split = model.cond_stage_model.tokenizer.tokenize(query_text)

                        sen_text_embedding = model.get_learned_conditioning(query_text)
                        class_embedding = sen_text_embedding[:, 5:len(c_split) + 1, :]

                        if class_embedding.size()[1] > 1:
                            class_embedding = torch.unsqueeze(class_embedding.mean(1), 1)
                        text_embedding = class_embedding

                        text_embedding = text_embedding.repeat(batch_size, 1, 1)

                        pred_seg_total = seg_module(diffusion_features, text_embedding)

                        pred_seg = torch.unsqueeze(pred_seg_total[0, 0, :, :], 0).unsqueeze(0)

                        label_pred_prob = torch.sigmoid(pred_seg)
                        label_pred_mask = torch.zeros_like(label_pred_prob, dtype=torch.float32)
                        label_pred_mask[label_pred_prob > 0.5] = 1
                        annotation_pred = label_pred_mask[0][0].cpu()

                        mask = annotation_pred.numpy()
                        mask = np.expand_dims(mask, 0)
                        done_image_mask = plot_mask(img, mask, alpha=0.9, indexlist=[0])
                        cv2.imwrite(os.path.join("demo/demo_mask.png"), done_image_mask)

                        torchvision.utils.save_image(annotation_pred, os.path.join("demo/demo_segresult.png"), normalize=True, scale_each=True)


if __name__ == "__main__":
    main()