Wendy-Fly committed
Commit e299762 · verified · 1 Parent(s): c1fbac9

Upload generate_prompt.py with huggingface_hub

Files changed (1)
  1. generate_prompt.py +18 -8
generate_prompt.py CHANGED
@@ -18,32 +18,42 @@ def write_json(file_path, data):
 
 # default: Load the model on the available device(s)
 print(torch.cuda.device_count())
-model_path = "/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/ckpt_7B"
+#model_path = "/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/ckpt_7B"
 # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 #     model_path, torch_dtype="auto", device_map="auto"
 # )
 
+
 # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_path", type=str, default="/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/ckpt_7B")
+parser.add_argument("--begin", type=int, default=0)
+parser.add_argument("--end", type=int, default=4635)
+parser.add_argument("--batch_size", type=int, default=3)
+parser.add_argument("--data_path", type=str, default="/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/magicbrush_dataset/dataset.json")
+parser.add_argument("--prompt_path", type=str, default="/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/magicbrush_dataset/gen.json")
+
+args = parser.parse_args()
 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    model_path,
+    args.model_path,
     torch_dtype=torch.bfloat16,
     attn_implementation="flash_attention_2",
     device_map="auto",
 )
 
 # default processor
-processor = AutoProcessor.from_pretrained(model_path)
+processor = AutoProcessor.from_pretrained(args.model_path)
 print(model.device)
 
 
 
 
-data = read_json('/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/magicbrush_dataset/dataset.json')
+data = read_json(args.data_path)
 save_data = []
 correct_num = 0
-begin = 0
-end = len(data)
-batch_size = 1
+begin = args.begin
+end = args.end
+batch_size = args.batch_size
 for batch_idx in tqdm(range(begin, end, batch_size)):
     batch = data[batch_idx:batch_idx + batch_size]
 
@@ -99,7 +109,7 @@ for batch_idx in tqdm(range(begin, end, batch_size)):
     save_list[idx]['result'] = x
     save_data.append(save_list[idx])
 
-json_path = "image_path = '/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/magicbrush_dataset/gen.json'"
+json_path = args.prompt_path
 write_json(json_path,save_data)
 
 
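The hunks call two helpers, read_json and write_json, whose bodies sit above the shown range (only the signature def write_json(file_path, data): is visible in the first hunk header), and none of the added lines imports argparse, so that import is presumably already near the top of generate_prompt.py. A minimal sketch of the two helpers, consistent with how they are called here and not taken from the repository:

import json

def read_json(file_path):
    # Load the dataset JSON (expected to be a list of records) from disk.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def write_json(file_path, data):
    # Write the accumulated results (save_data) back out as JSON.
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)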
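The body of the batch loop (old lines 50-98) falls outside both hunks, so only its first and last statements are visible: it runs batched Qwen2.5-VL inference and stores each decoded output in save_list[idx]['result']. A sketch of what that elided middle section presumably resembles, following the standard Transformers / qwen_vl_utils batch-inference pattern (record field names such as "image" and "instruction" are assumptions, not the script's actual keys):

from qwen_vl_utils import process_vision_info

def run_batch(model, processor, batch, max_new_tokens=128):
    # Build one chat-style message list per record in the batch.
    messages = [
        [{"role": "user",
          "content": [{"type": "image", "image": item["image"]},
                      {"type": "text", "text": item["instruction"]}]}]
        for item in batch
    ]
    texts = [processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
             for m in messages]
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=texts, images=image_inputs, videos=video_inputs,
                       padding=True, return_tensors="pt").to(model.device)
    generated = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Strip the prompt tokens so only the newly generated text is decoded.
    trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated)]
    return processor.batch_decode(trimmed, skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)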
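With --begin, --end, and --batch_size exposed on the command line, a run can now be restricted to a slice of the 4,635-entry dataset (the default --end) instead of always sweeping the whole file. A hypothetical launcher, not part of the commit, that prints one generate_prompt.py invocation per shard so slices can be processed in parallel or re-run independently:

def shard_bounds(total, num_shards):
    # Yield non-overlapping (begin, end) pairs that together cover range(total).
    step = -(-total // num_shards)  # ceiling division
    for begin in range(0, total, step):
        yield begin, min(begin + step, total)

if __name__ == "__main__":
    for begin, end in shard_bounds(4635, 4):
        print(f"python generate_prompt.py --begin {begin} --end {end} --batch_size 3")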