Upload generate_prompt.py with huggingface_hub
generate_prompt.py  CHANGED  (+18 -8)
@@ -18,32 +18,42 @@ def write_json(file_path, data):
 
 # default: Load the model on the available device(s)
 print(torch.cuda.device_count())
-model_path = "/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/ckpt_7B"
+#model_path = "/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/ckpt_7B"
 # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 #     model_path, torch_dtype="auto", device_map="auto"
 # )
 
+
 # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_path", type=str, default="/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/ckpt_7B")
+parser.add_argument("--begin", type=int, default=0)
+parser.add_argument("--end", type=int, default=4635)
+parser.add_argument("--batch_size", type=int, default=3)
+parser.add_argument("--data_path", type=str, default="/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/magicbrush_dataset/dataset.json")
+parser.add_argument("--prompt_path", type=str, default="/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/magicbrush_dataset/gen.json")
+
+args = parser.parse_args()
 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    model_path,
+    args.model_path,
     torch_dtype=torch.bfloat16,
     attn_implementation="flash_attention_2",
     device_map="auto",
 )
 
 # default processor
-processor = AutoProcessor.from_pretrained(model_path)
+processor = AutoProcessor.from_pretrained(args.model_path)
 print(model.device)
 
 
 
 
-data = read_json(
+data = read_json(args.data_path)
 save_data = []
 correct_num = 0
-begin =
-end =
-batch_size =
+begin = args.begin
+end = args.end
+batch_size = args.batch_size
 for batch_idx in tqdm(range(begin, end, batch_size)):
     batch = data[batch_idx:batch_idx + batch_size]
 
@@ -99,7 +109,7 @@ for batch_idx in tqdm(range(begin, end, batch_size)):
         save_list[idx]['result'] = x
         save_data.append(save_list[idx])
 
-json_path =
+json_path = args.prompt_path
 write_json(json_path,save_data)
 
 
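With the hardcoded settings lifted into flags, the defaults reproduce the script's previous behavior, and a run can be pointed at a different checkpoint or dataset slice from the command line, e.g.:

python generate_prompt.py --model_path /home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/ckpt_7B --begin 0 --end 4635 --batch_size 3

A natural use of the begin/end pair is sharding the dataset across jobs (say, --begin 0 --end 2000 on one GPU and --begin 2000 --end 4635 on another), each with its own --prompt_path output. Note that the hunks assume import argparse (along with torch, tqdm, and the transformers classes) already appears near the top of the file; those lines fall outside the diff context shown here.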
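The loop body between the two hunks (roughly lines 59-108 of the new file) is not shown in this diff. For orientation only, a typical batched Qwen2.5-VL generation step with this model/processor pair looks like the sketch below; it follows the standard transformers + qwen_vl_utils usage, not necessarily this script's exact code, and build_messages is a hypothetical stand-in for however the script turns a dataset record into a chat-format message list.

from qwen_vl_utils import process_vision_info  # official Qwen2.5-VL helper

# `batch`, `processor`, and `model` come from the script above;
# build_messages is a hypothetical helper, not part of this commit.
messages_batch = [build_messages(sample) for sample in batch]

# Render each message list to a prompt string and collect the vision inputs.
texts = [
    processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
    for m in messages_batch
]
image_inputs, video_inputs = process_vision_info(messages_batch)

inputs = processor(
    text=texts,
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to(model.device)

# Generate, then strip the prompt tokens from each output before decoding.
generated_ids = model.generate(**inputs, max_new_tokens=128)
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
results = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)

Storing each decoded string on its matching record would be consistent with the save_list[idx]['result'] = x line visible in the second hunk.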