kelly0000 committed
Commit e1050c3 · verified · 1 Parent(s): 25157a7

Upload demo_caption_elements.py

Files changed (1)
  1. demo_caption_elements.py +176 -0
demo_caption_elements.py ADDED
@@ -0,0 +1,176 @@
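+ # Demo script: for each screenshot in the Auto-GUI caption dataset, ask MiniCPM-Llama3-V-2.5
+ # to list the interactable applications/elements it sees, then ask llama3-70b (via LlamaAPI)
+ # to normalise that raw answer into a Python list, and write both back into the dataset JSON.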
+ import os
+ import json
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+
+
+ def read_json(file_path):
+     with open(file_path, 'r', encoding='utf-8') as file:
+         data = json.load(file)
+     return data
+
+ def write_json(file_path, data):
+     with open(file_path, 'w', encoding='utf-8') as file:
+         json.dump(data, file, ensure_ascii=False, indent=4)
+
+
+ import os
+ from openai import OpenAI
+ import pprint
+ import json
+ from llamaapi import LlamaAPI
+
+ # Initialize the SDK
+ llama = LlamaAPI("LL-SmrO4FiBWvkfaGskA4fe6qLSVa7Ob5B83jOojHNq8HkrukjRRG4Xt3CF1mLV9u6o")
+ os.environ["OPENAI_API_KEY"] = "sk-proj-Jmlrkk0HauWRhffybWOKT3BlbkFJIIuX6dFVCyVG7y6lGwsh"
+
+
+ # client = OpenAI()
+ # def reponse(sample):
+ #     completion = client.chat.completions.create(
+ #         model="gpt-3.5-turbo",
+ #         # model="gpt-4",
+ #         # model= "gpt-4-1106-vision-preview",
+ #         messages=[
+ #             {"role": "system", "content": ""},
+ #             {"role": "user", "content": sample}
+ #         ]
+ #     )
+
+ #     # print(completion.choices[0].message.content)
+ #     return completion.choices[0].message.content
+ #     return completion
+
+
+
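+ # MiniCPM-Llama3-V-2.5 serves as the vision-language model that reads each screenshot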
+ from chat import MiniCPMVChat, img2base64
+ import torch
+ import json
+ from PIL import Image
+
+
+ torch.manual_seed(0)
+ chat_model = MiniCPMVChat('/code/ICLR_2024/Model/MiniCPM-Llama3-V-2_5')
+
+
+ # Sanity-check query on a single screenshot before processing the dataset
+ image_path = '/code/ICLR_2024/SeeClick/output_image_27.png'
+ # image = Image.open(image_path)
+ # image.show()
+
+ qs = """
+ List all the application names and locations in the image that can be interacted with; the result should be formatted as a list
+ """
+
+ im_64 = img2base64(image_path)
+ msgs = [{"role": "user", "content": qs}]
+ inputs = {"image": im_64, "question": json.dumps(msgs)}
+ answer = chat_model.chat(inputs)
+
+ data = read_json("/code/ICLR_2024/Auto-GUI/dataset/blip/single_blip_train_llava_10000_caption_elements_llama3_70b.json")
+
+
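+ # Map each image path in the dataset to its index so records can be looked up by filename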
+ retrival_dict = {}
+ for index, i in enumerate(data):
+     retrival_dict[i['image']] = index
+
+ path = '/code/ICLR_2024/Auto-GUI/dataset/'
+ image_id = [x['image'].split('/')[2].split('.')[0] for x in data]
+
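+ # Group image ids by their prefix (the part before '_'), collecting the suffixes that share it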
+ all_pair_id = {}
+ all_pair_key = []
+ for i in image_id:
+     key = i.split('_')[0]
+     all_pair_id[key] = []
+     all_pair_key.append(key)
+
+ for i in image_id:
+     key = i.split('_')[0]
+     value = i.split('_')[1]
+     all_pair_id[key].append(value)
+
+ all_pair_key = list(set(all_pair_key))
+ path2 = 'blip/single_texts_splits/'
+
+
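+ # Main loop, resuming from key index 770: query MiniCPM for each screenshot's interactable
+ # elements, then have llama3-70b reformat the raw answer into a Python list named `elements`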
+ from tqdm import tqdm
+ for i in tqdm(all_pair_key[770:]):
+
+     num_list = all_pair_id[i]
+     for j in num_list:
+
+         retival_path = path2 + i + '_' + j + '.png'
+         new_path = path + path2 + i + '_' + j + '.png'
+         ids = retrival_dict[retival_path]
+
+         image_path = path + data[ids]['image']
+         caption = data[ids]['caption']
+         Previous = data[ids]['conversations'][0]['value']
+
+         Previous = Previous.lower()
+         task = Previous.split('goal')[1]
+
+         # Step 1: ask MiniCPM to list the interactable elements in this screenshot
+         Demo_prompt_step1 = """
+ List all the application names and locations in the image that can be interacted with; the result should be formatted as a list
+ """
+
+         im_64 = img2base64(image_path)
+         msgs = [{"role": "user", "content": Demo_prompt_step1}]
+         inputs = {"image": im_64, "question": json.dumps(msgs)}
+         answer = chat_model.chat(inputs)
+
+         data[ids]['icon_list_raw'] = answer
+         pprint.pprint(answer)
+
+         # Step 2: have llama3-70b reformat the raw answer into a Python list assignment
+         prompt = """ ##### Refine it to a list; the list name must be elements, just like:
+ elements = [
+ "Newegg",
+ "Newegg CEO",
+ "Newegg customer service",
+ "Newegg founder",
+ "Newegg promo code",
+ "Newegg return policy",
+ "Newegg revenue",
+ "Newegg military discounts"]
+
+ Answer with the Python list only!
+ ##### """
+
+         import time
+         time.sleep(2)
+
+         api_request_json = {
+             "model": "llama3-70b",
+             "messages": [
+                 {"role": "system", "content": "You are an assistant that will handle the corresponding text formatting for me."},
+                 {"role": "user", "content": answer + prompt},
+
+             ],
+             "max_tokens": 1024
+
+         }
+
+         try:
+             # new_answer = reponse(answer + prompt)  # GPT-4 version
+             response = llama.run(api_request_json)
+             new_answer = response.json()['choices'][0]['message']['content']
+             print('======================================================')
+             pprint.pprint(new_answer)
+             print('======================================================')
+         except Exception as e:
+             print(f"Error in LLAMA API generation: {e}")
+             import time
+             time.sleep(30)
+             continue
+
+         try:
+             # The model is expected to return a literal `elements = [...]` assignment; exec() runs it
+             exec(new_answer)
+             data[ids]['icon_list'] = elements
+         except Exception as e:
+             print(f"Error in setting data[ids]['icon_list']: {e}")
+             continue
+
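+         # Hedged sketch (not in the original script): instead of exec(), the reply could be
+         # parsed more defensively with ast.literal_eval, assuming the model answers with a
+         # single bracketed list:
+         #     import ast
+         #     list_text = new_answer[new_answer.find('['):new_answer.rfind(']') + 1]
+         #     data[ids]['icon_list'] = ast.literal_eval(list_text)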
+
+
+ # Write the updated dataset (now holding icon_list_raw / icon_list per record) back to disk
+ write_json('/code/ICLR_2024/Auto-GUI/dataset/blip/single_blip_train_llava_10000_caption_elements_llama3_70b.json', data)
+
+