File size: 6,148 Bytes
b122459 9226937 b122459 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import os
from tqdm.auto import tqdm
from PIL import Image
import torch as T
import transformers, diffusers
from mgie_llava import LlavaLlamaForCausalLM_
from llava.conversation import conv_templates
from llava.model import *
import json
def read_json(file_path):
    """Parse the UTF-8 encoded JSON file at *file_path* and return the result."""
    with open(file_path, encoding='utf-8') as handle:
        return json.load(handle)
def write_json(file_path, data):
    """Serialize *data* to *file_path* as UTF-8 JSON with a 4-space indent.

    Non-ASCII characters are written as-is instead of being \\u-escaped.
    An existing file at *file_path* is overwritten.
    """
    with open(file_path, mode='w', encoding='utf-8') as handle:
        json.dump(data, handle, indent=4, ensure_ascii=False)
def crop_resize(f, sz=512):
    """Center-crop image *f* to a square and resize it to ``sz`` x ``sz``.

    The shorter side is kept in full and the longer side is trimmed from
    both ends. The offset uses floor division, so when the difference is
    odd the extra pixel is removed from the right/bottom edge. Square inputs
    are only resized.
    """
    width, height = f.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    if width != height:
        f = f.crop([left, top, left + side, top + side])
    return f.resize([sz, sz])
def remove_alter(s):  # hack expressive instruction
    """Trim raw MLLM output *s* down to a short expressive instruction.

    Steps, applied in order:
      1. keep only the text after the first ``'ASSISTANT:'`` marker;
      2. cut at the first ``'</s>'``;
      3. drop everything from the first case-insensitive ``'alternative'``;
      4. drop everything from the first ``'[IMG0]'``;
      5. keep at most the first two ``'.'``-separated pieces, strip each,
         rejoin them with ``'.'`` (no space) and make sure the result ends
         with a period.

    Returns:
        The trimmed instruction, or ``''`` when nothing is left after
        trimming. The previous version raised IndexError on ``s[-1]``
        in that case.
    """
    marker = 'ASSISTANT:'
    if marker in s:
        s = s[s.index(marker) + len(marker):].strip()
    if '</s>' in s:
        s = s[:s.index('</s>')].strip()
    lowered = s.lower()
    if 'alternative' in lowered:
        s = s[:lowered.index('alternative')]
    if '[IMG0]' in s:
        s = s[:s.index('[IMG0]')]
    # Use a distinct name for the pieces; the old code reused `s` here.
    pieces = [piece.strip() for piece in s.split('.')[:2]]
    s = '.'.join(pieces)
    if not s:
        # Nothing usable was generated, so don't index into an empty string.
        return ''
    if not s.endswith('.'):
        s += '.'
    return s.strip()
# Placeholder strings used when building the LLaVA prompt and when
# registering image tokens with the tokenizer. DEFAULT_IMAGE_TOKEN is
# defined but not referenced by the rest of this script.
(DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN,
 DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN) = (
    '<image>', '<im_patch>', '<im_start>', '<im_end>')
# --- Multimodal LLM: LLaVA-7B base + MGIE fine-tuned weights ---------------
PATH_LLAVA = '/home/zbz5349/WorkSpace/aigeeks/ml-mgie/_ckpt/LLaVA-7B-v1'
tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA)
model = LlavaLlamaForCausalLM_.from_pretrained(
    PATH_LLAVA, low_cpu_mem_usage=True, torch_dtype=T.float16, use_cache=True,
).cuda()
image_processor = transformers.CLIPImageProcessor.from_pretrained(
    model.config.mm_vision_tower, torch_dtype=T.float16)
tokenizer.padding_side = 'left'
# Eight visual tokens [IMG0]..[IMG7]; their hidden states later feed model.edit_head.
tokenizer.add_tokens([f'[IMG{k}]' for k in range(8)], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
# strict=False: keys present in only one of checkpoint/model are ignored.
ckpt = T.load('./_ckpt/mgie_7b/mllm.pt', map_location='cpu')
model.load_state_dict(ckpt, strict=False)

mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

# Swap the model's vision tower for a freshly loaded fp16 CLIP encoder on the GPU,
# then record the image-token ids on its config.
vision_tower = model.get_model().vision_tower[0]
vision_tower = transformers.CLIPVisionModel.from_pretrained(
    vision_tower.config._name_or_path, torch_dtype=T.float16, low_cpu_mem_usage=True,
).cuda()
model.get_model().vision_tower[0] = vision_tower
vision_config = vision_tower.config
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
vision_config.use_im_start_end = mm_use_im_start_end
if mm_use_im_start_end:
    vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
# One patch token per vision patch: (image_size // patch_size) ** 2.
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
_ = model.eval()

# The checkpoint's 'emb' entry is passed to edit_head alongside the hidden states.
EMB = ckpt['emb'].cuda()
# Release the CPU-side checkpoint. load_state_dict has already copied its tensors
# into the model, and keeping the dict alive would hold all of them in host RAM
# for the entire run.
del ckpt
# Negative prompt embedding for the diffusion pipeline: edit_head applied to
# all-zero hidden states for the 8 [IMG] slots (4096 assumed to be the LLM
# hidden size - TODO confirm against model.config.hidden_size).
with T.inference_mode():
    NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)
print('NULL:', NULL.shape)

# --- Diffusion model: InstructPix2Pix with the MGIE-tuned UNet --------------
pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained(
    'timbrooks/instruct-pix2pix', torch_dtype=T.float16, safety_checker=None,
).to('cuda')
pipe.set_progress_bar_config(disable=True)
pipe.unet.load_state_dict(T.load('./_ckpt/mgie_7b/unet.pt', map_location='cpu'))
SEED = 13331
# Example instructions (unused; kept for reference):
# ins = ['make the frame red', 'turn the day into night', 'give him a beard', 'make cottage a mansion',
#        'remove yellow object from dogs paws', 'change the hair from red to blue', 'remove the text', 'increase the image contrast',
#        'remove the people in the background', 'please make this photo professional looking', 'darken the image, sharpen it', 'photoshop the girl out',
#        'make more brightness', 'take away the brown filter form the image', 'add more contrast to simulate more light', 'dark on rgb',
#        'make the face happy', 'change view as ocean', 'replace basketball with soccer ball', 'let the floor be made of wood']
data_path = '/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/magicbrush_dataset/genp2_4_multi.json'
save_image = '/home/zbz5349/WorkSpace/aigeeks/ml-mgie/all'
os.makedirs(save_image, exist_ok=True)
# If there are x instructions, generate x (single) + x (mix) + 1 (all) images.
NUM_SAMPLES = 100  # process at most this many dataset entries

# Token id the upstream MGIE code hard-codes for '[IMG0]'. Assumed to match
# tokenizer.convert_tokens_to_ids('[IMG0]') - verify if the tokenizer changes.
IMG0_ID = 32003
# Image placeholder appended to every prompt: start token, one patch token per
# vision patch, end token. It does not depend on the sample, so build it once.
image_placeholder = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN

data = read_json(data_path)
saved_names = set()
# Slicing (instead of range(100)) avoids IndexError on datasets with < NUM_SAMPLES entries.
for sample in tqdm(data[:NUM_SAMPLES]):
    img_path = sample["content"][0]["image"]
    instruction = sample["content"][1]["text"]
    out_name = os.path.basename(img_path)

    # Open once, decode to RGB, and close the file handle. The same RGB image
    # feeds both the CLIP preprocessor and the diffusion pipeline.
    with Image.open(img_path) as fh:
        src = fh.convert('RGB')
    pixel_values = image_processor.preprocess(src, return_tensors='pt')['pixel_values'][0]

    prompt = f"what will this image be like if '{instruction}'" + '\n' + image_placeholder
    conv = conv_templates['vicuna_v1'].copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    enc = tokenizer(conv.get_prompt())
    input_ids, mask = T.as_tensor(enc['input_ids']), T.as_tensor(enc['attention_mask'])

    with T.inference_mode():
        gen = model.generate(input_ids.unsqueeze(dim=0).cuda(), images=pixel_values.half().unsqueeze(dim=0).cuda(),
                             attention_mask=mask.unsqueeze(dim=0).cuda(),
                             do_sample=False, max_new_tokens=96, num_beams=1, no_repeat_ngram_size=3,
                             return_dict_in_generate=True, output_hidden_states=True)
        tokens = gen['sequences'][0].tolist()
        # Last-layer hidden states from every generation step, concatenated on dim 1.
        hid = T.cat([x[-1] for x in gen['hidden_states']], dim=1)[0]
        # Start index: one position before the first [IMG0] token (or len(hid) - 9
        # if none was generated), capped at len(hid) - 9. Then take 8 rows.
        p = min(tokens.index(IMG0_ID) - 1 if IMG0_ID in tokens else len(hid) - 9, len(hid) - 9)
        hid = hid[p:p + 8]
        # The decoded text instruction (remove_alter(tokenizer.decode(tokens))) is not
        # needed for editing; compute it here only if you want to log it.
        emb = model.edit_head(hid.unsqueeze(dim=0), EMB)
        res = pipe(image=src, prompt_embeds=emb, negative_prompt_embeds=NULL,
                   generator=T.Generator(device='cuda').manual_seed(SEED)).images[0]

    # Outputs are named after the input file; flag entries that reuse a basename.
    if out_name in saved_names:
        tqdm.write(f'warning: {out_name} was already written in this run; overwriting')
    saved_names.add(out_name)
    res.save(os.path.join(save_image, out_name))