kelly0000 committed on
Commit
a5aab08
·
verified ·
1 Parent(s): c5583e4

Upload caption_aitw_v2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. caption_aitw_v2.py +95 -0
caption_aitw_v2.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import json
4
+ from PIL import Image
5
+ import pprint
6
+ from tqdm import tqdm
7
+ from multiprocessing import Pool, cpu_count
8
+
9
+
10
+ from chat import MiniCPMVChat, img2base64
11
+
12
+
13
+
14
def read_json(file_path):
    """Load and return the JSON document stored at *file_path* (UTF-8)."""
    with open(file_path, 'r', encoding='utf-8') as fp:
        return json.load(fp)
18
+
19
def write_json(file_path, data):
    """Serialize *data* to *file_path* as pretty-printed UTF-8 JSON.

    The write is atomic: content goes to a sibling ``.tmp`` file first and
    is then moved into place with ``os.replace``. This matters because the
    script below calls this function as a periodic checkpoint inside a long
    annotation loop — a crash or kill mid-write must not leave a truncated
    checkpoint behind.

    Args:
        file_path: Destination path for the JSON file.
        data: Any ``json``-serializable object.
    """
    tmp_path = f"{file_path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as file:
        # ensure_ascii=False keeps non-ASCII (e.g. CJK) text readable on disk.
        json.dump(data, file, ensure_ascii=False, indent=4)
    os.replace(tmp_path, file_path)  # atomic on POSIX and Windows
22
+
23
def preprocess_data(data, path_base):
    """Attach a base64 encoding of each item's image under 'image_base64'.

    Caching the encoding up front avoids re-reading the same image files
    repeatedly later on. Mutates the items in *data* in place and returns
    the same list for convenience.
    """
    for record in data:
        full_path = os.path.join(path_base, record['image'])
        record['image_base64'] = img2base64(full_path)
    return data
29
+
30
+
31
+
32
def chat_minicpm_application(image_path):
    """Ask the module-level MiniCPM-V model to list the interactive
    applications visible in the screenshot at *image_path*.

    Returns the model's raw text answer.
    """
    prompt = """
    List the names and locations of all interactive applications in the image, as well as their functionality and potential applications.
    """
    conversation = json.dumps([{"role": "user", "content": prompt}])
    payload = {"image": img2base64(image_path), "question": conversation}
    return chat_model.chat(payload)
44
+
45
+
46
def chat_minicpm_content(image_path):
    """Ask the module-level MiniCPM-V model for a free-form description of
    the screenshot at *image_path*.

    Returns the model's raw text answer.
    """
    prompt = """
    Describe the content of this image.
    """
    conversation = json.dumps([{"role": "user", "content": prompt}])
    payload = {"image": img2base64(image_path), "question": conversation}
    return chat_model.chat(payload)
57
+
58
def chat_minicpm_mind(image_path):
    """Ask the module-level MiniCPM-V model to explain why the region
    highlighted by the green frame in *image_path* was clicked.

    Returns the model's raw text answer.
    """
    prompt = """
    The green frame in the picture represents the situation of clicking, need to explain why click in the corresponding area. Answer template: The green box ....
    """
    conversation = json.dumps([{"role": "user", "content": prompt}])
    payload = {"image": img2base64(image_path), "question": conversation}
    return chat_model.chat(payload)
69
+
70
+
71
+
72
# --- Script entry: annotate AITW click samples with MiniCPM-V captions. ---

# Fixed seed for reproducible generation across runs.
torch.manual_seed(0)
chat_model = MiniCPMVChat('/code/Model/MiniCPM-Llama3-V-2_5')
path_base = '/code/Auto-GUI/dataset/'

data = read_json("/code/Auto-GUI/dataset/mind/general_blip_train_llava_coco.json")
# Keep only dual-point (click) samples; the slice resumes a previous run
# from sample 17370 onward.
data = [line for line in data if line['action_type'] == '#DUAL_POINT#'][17370:]

_save_path = '/code/MiniCPM-V/general_blip_train_llava_coco_caption_mind2.json'

# Counting from 1 so the periodic-save check below is straightforward.
for step, sample in enumerate(tqdm(data), start=1):
    screenshot = path_base + sample['image']
    sample['application'] = chat_minicpm_application(screenshot)
    sample['content'] = chat_minicpm_content(screenshot)
    sample['mind'] = chat_minicpm_mind(screenshot)

    # Checkpoint every 100 samples so a crash doesn't lose all progress.
    if step % 100 == 0:
        write_json(_save_path, data)

# Final save covers the trailing (< 100) samples as well.
write_json(_save_path, data)
95
+