import argparse
import json

import torch
from tqdm import tqdm
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

def read_json(file_path): 
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data

def write_json(file_path, data):
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=4)
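
# The loop below assumes each record in dataset.json is a single user message
# in Qwen chat format (inferred from the indexing in the batch loop; the
# actual schema may carry extra fields):
#
#   {"role": "user",
#    "content": [{"type": "image", "image": "<path/to/image>"},
#                {"type": "text",  "text": "<editing instruction>"}]}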

# Report how many GPUs are visible; device_map="auto" below shards the model
# across all of them.
print(torch.cuda.device_count())

# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/ckpt_7B")
parser.add_argument("--begin", type=int, default=0)
parser.add_argument("--end", type=int, default=4635)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--data_path", type=str, default="/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/magicbrush_dataset/dataset.json")
parser.add_argument("--prompt_path", type=str, default="/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/magicbrush_dataset/gen.json")

args = parser.parse_args()
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    args.model_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
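
# flash_attention_2 requires the flash-attn package. If it is not installed,
# "sdpa" is a reasonable fallback (a sketch, assuming a transformers version
# that accepts attn_implementation="sdpa"):
#
#   model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#       args.model_path, torch_dtype=torch.bfloat16,
#       attn_implementation="sdpa", device_map="auto",
#   )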

# Default processor; switch to left padding so batched generation appends new
# tokens directly after each prompt (right padding would leave pad tokens
# between a prompt and its completion).
processor = AutoProcessor.from_pretrained(args.model_path)
processor.tokenizer.padding_side = "left"
print(model.device)

data = read_json(args.data_path)
save_data = []
begin = args.begin
end = args.end
batch_size = args.batch_size
json_path = args.prompt_path


for batch_idx in tqdm(range(begin, end, batch_size)):
    batch = data[batch_idx:min(batch_idx + batch_size, end)]
    data_list = []
    save_list = []
    for messages in batch:
        save_ = {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "",
                },
                {"type": "text",
                 "text": "Please help me write a prompt for image editing on this picture. The requirements are as follows: complex editing instructions should include two to five simple editing instructions involving spatial relationships (simple editing instructions such as ADD: add an object to the left of a certain object, DELETE: delete a certain object, MODIFY: change a certain object into another object). We hope that the editing instructions can have simple reasoning and can also include some abstract concept-based editing (such as making the atmosphere more romantic, or making the diet healthier, or making the boy more handsome and the girl more beautiful, etc.). Please give me clear editing instructions and also consider whether such editing instructions are reasonable."},
            ],
            "result": ""
        }
        # Copy the per-sample image path and prompt over the template defaults.
        save_["content"][0]["image"] = messages["content"][0]["image"]
        save_["content"][1]["text"] = messages["content"][1]["text"]

        data_list.append(messages)
        save_list.append(save_)

    # Apply the chat template per sample so the processor sees one prompt per
    # conversation rather than a single merged multi-image conversation.
    texts = [
        processor.apply_chat_template([m], tokenize=False, add_generation_prompt=True)
        for m in data_list
    ]
    image_inputs, video_inputs = process_vision_info(data_list)
    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    # Inference: generate, then strip each prompt (including padding) from the
    # front of its generated sequence.
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    # Store one result per sample in the batch.
    for save_, result in zip(save_list, output_text):
        save_["result"] = result
        save_data.append(save_)

    # Checkpoint every fourth batch so an interruption loses little work.
    if ((batch_idx - begin) // batch_size) % 4 == 0:
        write_json(json_path, save_data)
        print(len(save_data))

# Final flush of any remaining results.
write_json(json_path, save_data)
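
# Example invocation (script name and paths are placeholders; adjust to your
# environment):
#
#   python gen_edit_prompts.py \
#       --model_path ./ckpt_7B \
#       --data_path ./magicbrush_dataset/dataset.json \
#       --prompt_path ./magicbrush_dataset/gen.json \
#       --begin 0 --end 4635 --batch_size 16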