Purple11 committed
Commit 59cc705 · 1 Parent(s): c68eb58

Delete test.py

Files changed (1)
  1. test.py +0 -376
test.py DELETED
@@ -1,376 +0,0 @@
- import argparse, os
- import cv2
- import torch
- import numpy as np
- import torchvision
- from omegaconf import OmegaConf
- from PIL import Image
- from tqdm import tqdm, trange
- from itertools import islice
- from einops import rearrange
- from torchvision.utils import make_grid
- import time
- from pytorch_lightning import seed_everything
- from torch import autocast
- from contextlib import nullcontext
-
- from ldm.util import instantiate_from_config
- from ldm.models.diffusion.ddim import DDIMSampler
- from ldm.modules.diffusionmodules.openaimodel import clear_feature_dic,get_feature_dic
- from ldm.models.seg_module import Segmodule
-
- import numpy as np
-
- os.environ["CUDA_VISIBLE_DEVICES"] = "1"
-
- def chunk(it, size):
-     it = iter(it)
-     return iter(lambda: tuple(islice(it, size)), ())
-
-
- def numpy_to_pil(images):
-     """
-     Convert a numpy image or a batch of images to a PIL image.
-     """
-     if images.ndim == 3:
-         images = images[None, ...]
-     images = (images * 255).round().astype("uint8")
-     pil_images = [Image.fromarray(image) for image in images]
-
-     return pil_images
-
-
- def load_model_from_config(config, ckpt, verbose=False):
-     print(f"Loading model from {ckpt}")
-     pl_sd = torch.load(ckpt, map_location="cpu")
-     if "global_step" in pl_sd:
-         print(f"Global Step: {pl_sd['global_step']}")
-     sd = pl_sd["state_dict"]
-     model = instantiate_from_config(config.model)
-     m, u = model.load_state_dict(sd, strict=False)
-     if len(m) > 0 and verbose:
-         print("missing keys:")
-         print(m)
-     if len(u) > 0 and verbose:
-         print("unexpected keys:")
-         print(u)
-
-     model.cuda()
-     model.eval()
-     return model
-
-
- def put_watermark(img, wm_encoder=None):
-     if wm_encoder is not None:
-         img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-         img = wm_encoder.encode(img, 'dwtDct')
-         img = Image.fromarray(img[:, :, ::-1])
-     return img
-
-
- def load_replacement(x):
-     try:
-         hwc = x.shape
-         y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
-         y = (np.array(y)/255.0).astype(x.dtype)
-         assert y.shape == x.shape
-         return y
-     except Exception:
-         return x
-
- def plot_mask(img, masks, colors=None, alpha=0.8,indexlist=[0,1]) -> np.ndarray:
-     """Visualize segmentation mask.
-
-     Parameters
-     ----------
-     img: numpy.ndarray
-         Image with shape `(H, W, 3)`.
-     masks: numpy.ndarray
-         Binary images with shape `(N, H, W)`.
-     colors: numpy.ndarray
-         color for mask, shape `(N, 3)`.
-         if None, generate random color for mask
-     alpha: float, optional, default 0.8
-         Transparency of plotted mask
-
-     Returns
-     -------
-     numpy.ndarray
-         The image plotted with segmentation masks, shape `(H, W, 3)`
-
-     """
-     H,W= masks.shape[0],masks.shape[1]
-     color_list=[[255,97,0],[128,42,42],[220,220,220],[255,153,18],[56,94,15],[127,255,212],[210,180,140],[221,160,221],[255,0,0],[255,128,0],[255,255,0],[128,255,0],[0,255,0],[0,255,128],[0,255,255],[0,128,255],[0,0,255],[128,0,255],[255,0,255],[255,0,128]]*6
-     final_color_list=[np.array([[i]*512]*512) for i in color_list]
-
-     background=np.ones(img.shape)*255
-     count=0
-     colors=final_color_list[indexlist[count]]
-     for mask, color in zip(masks, colors):
-         color=final_color_list[indexlist[count]]
-         mask = np.stack([mask, mask, mask], -1)
-         img = np.where(mask, img * (1 - alpha) + color * alpha,background*0.4+img*0.6 )
-         count+=1
-     return img.astype(np.uint8)
-
- def main():
-     parser = argparse.ArgumentParser()
-
-     parser.add_argument(
-         "--prompt",
-         type=str,
-         nargs="?",
-         default="a photo of a lion on a mountain top at sunset",
-         help="the prompt to render"
-     )
-     parser.add_argument(
-         "--category",
-         type=str,
-         nargs="?",
-         default="lion",
-         help="the category to ground"
-     )
-     parser.add_argument(
-         "--outdir",
-         type=str,
-         nargs="?",
-         help="dir to write results to",
-         default="outputs/txt2img-samples"
-     )
-     parser.add_argument(
-         "--skip_grid",
-         action='store_true',
-         help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
-     )
-     parser.add_argument(
-         "--skip_save",
-         action='store_true',
-         help="do not save individual samples. For speed measurements.",
-     )
-     parser.add_argument(
-         "--ddim_steps",
-         type=int,
-         default=50,
-         help="number of ddim sampling steps",
-     )
-     parser.add_argument(
-         "--plms",
-         action='store_true',
-         help="use plms sampling",
-     )
-     parser.add_argument(
-         "--laion400m",
-         action='store_true',
-         help="uses the LAION400M model",
-     )
-     parser.add_argument(
-         "--fixed_code",
-         action='store_true',
-         help="if enabled, uses the same starting code across samples ",
-     )
-     parser.add_argument(
-         "--ddim_eta",
-         type=float,
-         default=0.0,
-         help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
-     )
-     parser.add_argument(
-         "--n_iter",
-         type=int,
-         default=1,
-         help="sample this often",
-     )
-     parser.add_argument(
-         "--H",
-         type=int,
-         default=512,
-         help="image height, in pixel space",
-     )
-     parser.add_argument(
-         "--W",
-         type=int,
-         default=512,
-         help="image width, in pixel space",
-     )
-     parser.add_argument(
-         "--C",
-         type=int,
-         default=4,
-         help="latent channels",
-     )
-     parser.add_argument(
-         "--f",
-         type=int,
-         default=8,
-         help="downsampling factor",
-     )
-     parser.add_argument(
-         "--n_samples",
-         type=int,
-         default=1,
-         help="how many samples to produce for each given prompt. A.k.a. batch size",
-     )
-     parser.add_argument(
-         "--n_rows",
-         type=int,
-         default=0,
-         help="rows in the grid (default: n_samples)",
-     )
-     parser.add_argument(
-         "--scale",
-         type=float,
-         default=7.5,
-         help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
-     )
-     parser.add_argument(
-         "--from-file",
-         type=str,
-         help="if specified, load prompts from this file",
-     )
-     parser.add_argument(
-         "--config",
-         type=str,
-         default="configs/stable-diffusion/v1-inference.yaml",
-         help="path to config which constructs model",
-     )
-     parser.add_argument(
-         "--sd_ckpt",
-         type=str,
-         default="stable_diffusion.ckpt",
-         help="path to checkpoint of stable diffusion model",
-     )
-     parser.add_argument(
-         "--grounding_ckpt",
-         type=str,
-         default="grounding_module.pth",
-         help="path to checkpoint of grounding module",
-     )
-     parser.add_argument(
-         "--seed",
-         type=int,
-         default=42,
-         help="the seed (for reproducible sampling)",
-     )
-     parser.add_argument(
-         "--precision",
-         type=str,
-         help="evaluate at this precision",
-         choices=["full", "autocast"],
-         default="autocast"
-     )
-     opt = parser.parse_args()
-
-     if opt.laion400m:
-         print("Falling back to LAION 400M model...")
-         opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
-         opt.ckpt = "models/ldm/text2img-large/model.ckpt"
-         opt.outdir = "outputs/txt2img-samples-laion400m"
-
-     seed_everything(opt.seed)
-
-     tic = time.time()
-     config = OmegaConf.load(f"{opt.config}")
-     model = load_model_from_config(config, f"{opt.sd_ckpt}")
-     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-     model = model.to(device)
-     toc = time.time()
-     seg_module=Segmodule().to(device)
-
-     seg_module.load_state_dict(torch.load(opt.grounding_ckpt, map_location="cpu"), strict=True)
-     print('load time:',toc-tic)
-     sampler = DDIMSampler(model)
-
-     os.makedirs(opt.outdir, exist_ok=True)
-     outpath = opt.outdir
-     batch_size = opt.n_samples
-     precision_scope = autocast if opt.precision=="autocast" else nullcontext
-     with torch.no_grad():
-         with precision_scope("cuda"):
-             with model.ema_scope():
-                 prompt = opt.prompt
-                 text = opt.category
-                 trainclass = text
-                 if not opt.from_file:
-                     assert prompt is not None
-                     data = [batch_size * [prompt]]
-
-                 else:
-                     print(f"reading prompts from {opt.from_file}")
-                     with open(opt.from_file, "r") as f:
-                         data = f.read().splitlines()
-                         data = list(chunk(data, batch_size))
-
-                 sample_path = os.path.join(outpath, "samples")
-                 os.makedirs(sample_path, exist_ok=True)
-
-                 start_code = None
-                 if opt.fixed_code:
-                     print('start_code')
-                     start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
-                 for n in trange(opt.n_iter, desc="Sampling"):
-                     for prompts in tqdm(data, desc="data"):
-                         clear_feature_dic()
-                         uc = None
-                         if opt.scale != 1.0:
-                             uc = model.get_learned_conditioning(batch_size * [""])
-                         if isinstance(prompts, tuple):
-                             prompts = list(prompts)
-
-                         c = model.get_learned_conditioning(prompts)
-                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
-                         samples_ddim,_, _ = sampler.sample(S=opt.ddim_steps,
-                                                            conditioning=c,
-                                                            batch_size=opt.n_samples,
-                                                            shape=shape,
-                                                            verbose=False,
-                                                            unconditional_guidance_scale=opt.scale,
-                                                            unconditional_conditioning=uc,
-                                                            eta=opt.ddim_eta,
-                                                            x_T=start_code)
-
-                         x_samples_ddim = model.decode_first_stage(samples_ddim)
-                         diffusion_features = get_feature_dic()
-
-
-                         x_sample = torch.clamp((x_samples_ddim[0] + 1.0) / 2.0, min=0.0, max=1.0)
-                         x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-
-                         Image.fromarray(x_sample.astype(np.uint8)).save("demo/demo.png")
-                         img = x_sample.astype(np.uint8)
-
-                         class_name = trainclass
-
-                         query_text ="a "+prompt.split()[1]+" of a "+class_name
-                         c_split = model.cond_stage_model.tokenizer.tokenize(query_text)
-
-                         sen_text_embedding = model.get_learned_conditioning(query_text)
-                         class_embedding = sen_text_embedding[:, 5:len(c_split)+1, :]
-
-                         if class_embedding.size()[1] > 1:
-                             class_embedding = torch.unsqueeze(class_embedding.mean(1), 1)
-                         text_embedding = class_embedding
-
-                         text_embedding = text_embedding.repeat(batch_size, 1, 1)
-
-
-                         pred_seg_total = seg_module(diffusion_features, text_embedding)
-
-
-                         pred_seg = torch.unsqueeze(pred_seg_total[0,0,:,:], 0).unsqueeze(0)
-
-                         label_pred_prob = torch.sigmoid(pred_seg)
-                         label_pred_mask = torch.zeros_like(label_pred_prob, dtype=torch.float32)
-                         label_pred_mask[label_pred_prob > 0.5] = 1
-                         annotation_pred = label_pred_mask[0][0].cpu()
-
-                         mask = annotation_pred.numpy()
-                         mask = np.expand_dims(mask, 0)
-                         done_image_mask = plot_mask(img, mask, alpha=0.9, indexlist=[0])
-                         cv2.imwrite(os.path.join("demo/demo_mask.png"), done_image_mask)
-
-                         torchvision.utils.save_image(annotation_pred, os.path.join("demo/demo_segresult.png"), normalize=True, scale_each=True)
-
-
-
- if __name__ == "__main__":
-     main()
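
Note on the grounding query built in the deleted script: it tokenizes "a photo of a <category>", slices the per-token text embedding from position 5 up to len(c_split)+1, and mean-pools multi-token class names into a single vector before repeating it across the batch. Below is a minimal, self-contained sketch of that pooling step; the random tensor, the 77x768 CLIP-style layout, and the hard-coded token list are stand-ins for the LDM text encoder and tokenizer, which are not reproduced here.

    import torch

    batch_size = 1
    # Stand-in for model.get_learned_conditioning(query_text): a start token at
    # index 0, then one embedding per token of "a photo of a lion".
    sen_text_embedding = torch.randn(1, 77, 768)
    c_split = ["a", "photo", "of", "a", "lion"]          # stand-in for the tokenizer output

    # Class tokens of "a photo of a <class>" begin at position 5 (index 0 is the
    # start token, positions 1-4 cover "a photo of a"), so slice them out ...
    class_embedding = sen_text_embedding[:, 5:len(c_split) + 1, :]

    # ... and mean-pool multi-token class names down to a single vector.
    if class_embedding.size(1) > 1:
        class_embedding = class_embedding.mean(1, keepdim=True)

    text_embedding = class_embedding.repeat(batch_size, 1, 1)
    print(text_embedding.shape)                          # torch.Size([1, 1, 768])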
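
The demo/demo_mask.png overlay is produced by passing the sigmoid of the grounding module's prediction through a 0.5 threshold and blending the resulting binary mask over the generated image with plot_mask. A small sketch of that threshold-and-blend logic on toy data follows; overlay_mask, the colour, and the random inputs are illustrative stand-ins, not code from the repository.

    import numpy as np
    import torch
    from PIL import Image

    def overlay_mask(img, prob, alpha=0.9, color=(255, 97, 0), threshold=0.5):
        """Threshold a (1, 1, H, W) probability map and blend it over an RGB image."""
        mask = (prob > threshold).float()[0, 0].cpu().numpy()        # (H, W) binary mask
        mask3 = np.stack([mask, mask, mask], -1).astype(bool)        # broadcast to 3 channels
        color_plane = np.ones_like(img) * np.array(color, dtype=np.float32)
        background = np.ones_like(img) * 255.0
        out = np.where(mask3,
                       img * (1 - alpha) + color_plane * alpha,      # highlight the predicted region
                       background * 0.4 + img * 0.6)                 # wash out everything else
        return out.astype(np.uint8)

    # Toy usage: a random image and random logits standing in for pred_seg.
    img = np.random.randint(0, 255, (512, 512, 3)).astype(np.float32)
    logits = torch.randn(1, 1, 512, 512)
    Image.fromarray(overlay_mask(img, torch.sigmoid(logits))).save("demo_mask_sketch.png")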