# Research / demo_caption_elements.py
# Uploaded by kelly0000 ("Upload demo_caption_elements.py", commit e1050c3, verified).
import os
import json
# Select GPU 1 by default, but respect a CUDA_VISIBLE_DEVICES the caller already
# exported (e.g. `CUDA_VISIBLE_DEVICES=0 python demo_caption_elements.py`).
# Must run before torch is imported so CUDA sees the restricted device list.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")
def read_json(file_path):
    """Load and return the JSON document stored at *file_path*.

    The file is decoded as UTF-8. Any error from ``open`` (e.g.
    FileNotFoundError) or from ``json`` (JSONDecodeError) propagates.
    """
    with open(file_path, encoding='utf-8') as handle:
        return json.load(handle)
def write_json(file_path, data):
    """Write *data* to *file_path* as UTF-8 JSON (non-ASCII kept, indent=4).

    The JSON is first written to a sibling ``<file_path>.tmp`` file and then
    moved over *file_path* with ``os.replace``. A crash or a non-serializable
    value mid-dump therefore leaves the previous contents of *file_path*
    intact instead of truncated. This matters here because the script writes
    back to the same dataset file it loaded.

    Args:
        file_path: Destination path.
        data: A ``json``-serializable object.

    Raises:
        TypeError: if *data* contains a value ``json`` cannot serialize;
            *file_path* is left untouched.
        OSError: on filesystem errors.
    """
    tmp_path = f"{file_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
import os
from openai import OpenAI
import pprint
import json
from llamaapi import LlamaAPI
# Initialize the SDK.
# SECURITY: a Llama API key and an OpenAI API key were previously hard-coded on
# these lines. Anything written into source must be treated as leaked, so both
# keys should be revoked/rotated. Credentials are now read from the environment.
llama_api_key = os.environ.get("LLAMA_API_KEY")
if not llama_api_key:
    raise RuntimeError("LLAMA_API_KEY environment variable is not set")
llama = LlamaAPI(llama_api_key)
# The OpenAI client (used only by the commented-out GPT `reponse` helper) reads
# OPENAI_API_KEY from the environment on its own, so it is not set here.
# client = OpenAI()
# def reponse(sample):
# completion = client.chat.completions.create(
# model="gpt-3.5-turbo",
# # model="gpt-4",
# # model= "gpt-4-1106-vision-preview",
# messages=[
# {"role": "system", "content": ""},
# {"role": "user", "content": sample}
# ]
# )
# # print(completion.choices[0].message.content)
# return completion.choices[0].message.content
# return completion
from chat import MiniCPMVChat, img2base64
import torch
import json
from PIL import Image
# Seed torch's global RNG (presumably so any sampling in MiniCPM-V generation
# is reproducible across runs).
torch.manual_seed(0)
chat_model = MiniCPMVChat('/code/ICLR_2024/Model/MiniCPM-Llama3-V-2_5')

# Optional one-off smoke test of the element-listing prompt on a single image.
# Previously this always ran, but every value it produced (answer, image_path,
# im_64, msgs, inputs) was overwritten by the main loop before being read. It
# only cost an extra model call and a hard dependency on this unrelated image
# file, so it is now opt-in.
RUN_SMOKE_TEST = False
if RUN_SMOKE_TEST:
    image_path = '/code/ICLR_2024/SeeClick/output_image_27.png'
    qs = """
List all the application name and location in the image that can be interacted with, the result shoudl be like a list
"""
    im_64 = img2base64(image_path)
    msgs = [{"role": "user", "content": qs}]
    inputs = {"image": im_64, "question": json.dumps(msgs)}
    answer = chat_model.chat(inputs)
    pprint.pprint(answer)
# Load the caption dataset (a list of records with at least an 'image' key).
data = read_json("/code/ICLR_2024/Auto-GUI/dataset/blip/single_blip_train_llava_10000_caption_elements_llama3_70b.json")

# Relative image path -> index of its record in `data` (last occurrence wins).
retrival_dict = {record['image']: index for index, record in enumerate(data)}

path = '/code/ICLR_2024/Auto-GUI/dataset/'
# File stem ("<key>_<suffix>") of each image. Assumes paths shaped like
# "blip/single_texts_splits/<stem>.png" (third component) -- TODO confirm.
image_id = [x['image'].split('/')[2].split('.')[0] for x in data]

# Group stems by the part before the first underscore: key -> [suffix, ...].
# partition() keeps any further underscores in the suffix, so the path rebuilt
# later as key + '_' + suffix + '.png' matches the original file name.
all_pair_id = {}
for stem in image_id:
    key, _, value = stem.partition('_')
    all_pair_id.setdefault(key, []).append(value)

# Distinct keys in first-seen (dataset) order. The previous list(set(...)) gave
# an order for str keys that changes between interpreter runs (hash
# randomization), so the resume offset (770) used by the main loop picked a
# different subset of keys on every run.
all_pair_key = list(all_pair_id)
path2 = 'blip/single_texts_splits/'
from tqdm import tqdm
# Main pass. For every image group from START_INDEX onward:
#   1. MiniCPM-V lists the interactable elements in the image (free text).
#   2. Llama-3-70B reformats that text as a Python list literal.
#   3. Both results are stored on the dataset record.
import ast
import time

# Manual resume offset into all_pair_key (earlier keys presumably done in a prior run).
START_INDEX = 770

STEP1_PROMPT = """
List all the application name and location in the image that can be interacted with, the result shoudl be like a list
"""

REFINE_PROMPT = """ ##### refine it to a list, list name must be elements , just like:
elements = [
"Newegg",
"Newegg CEO",
"Newegg customer service",
"Newegg founder",
"Newegg promo code",
"Newegg return policy",
"Newegg revenue",
"Newegg military discounts"]
Answer the python list only!
##### """


def _parse_elements(text):
    """Return the Python list the LLM produced in *text*, without executing it.

    Accepts an ``elements = [...]`` assignment (what the prompt asks for) or a
    bare ``[...]`` list literal, optionally wrapped in Markdown code fences.
    Only literals are accepted (``ast.literal_eval``). This replaces the former
    ``exec(new_answer)``, which executed arbitrary model output. Because that
    ``exec`` ran at module scope, an answer that did not define ``elements``
    also silently reused the previous image's list.

    Raises:
        SyntaxError: if *text*, minus fence lines, is not valid Python.
        ValueError: if no list literal is found or it contains non-literals.
    """
    source = "\n".join(
        line for line in text.splitlines() if not line.lstrip().startswith("```")
    ).strip()
    for node in ast.parse(source).body:
        if isinstance(node, ast.Assign) and any(
            isinstance(target, ast.Name) and target.id == "elements"
            for target in node.targets
        ):
            candidate = node.value
        elif isinstance(node, ast.Expr) and isinstance(node.value, ast.List):
            candidate = node.value
        else:
            continue
        value = ast.literal_eval(candidate)
        if isinstance(value, list):
            return value
    raise ValueError("no `elements = [...]` list literal found in model output")


for i in tqdm(all_pair_key[START_INDEX:]):
    for j in all_pair_id[i]:
        ids = retrival_dict[path2 + i + '_' + j + '.png']
        image_path = path + data[ids]['image']

        # Step 1: free-text element list from the vision-language model.
        im_64 = img2base64(image_path)
        msgs = [{"role": "user", "content": STEP1_PROMPT}]
        inputs = {"image": im_64, "question": json.dumps(msgs)}
        answer = chat_model.chat(inputs)
        data[ids]['icon_list_raw'] = answer
        pprint.pprint(answer)

        # Step 2: ask the LLM to reformat it as `elements = [...]`.
        time.sleep(2)  # crude client-side rate limiting between API calls
        api_request_json = {
            "model": "llama3-70b",
            "messages": [
                {"role": "system", "content": "You are a assistant that will handle the corresponding text formatting for me."},
                {"role": "user", "content": answer + REFINE_PROMPT},
            ],
            "max_tokens": 1024
        }
        try:
            response = llama.run(api_request_json)
            new_answer = response.json()['choices'][0]['message']['content']
            print('======================================================')
            pprint.pprint(new_answer)
            print('======================================================')
        except Exception as e:
            print(f"Error in LLAMA API Generation : {e}")
            time.sleep(30)  # back off before the next request
            continue

        # Step 3: extract the list without executing the model's output.
        try:
            data[ids]['icon_list'] = _parse_elements(new_answer)
        except Exception as e:
            print(f"Error in setting data[ids]['icon_list']: {e}")
# Persist the augmented records (icon_list_raw / icon_list) back to the dataset file.
write_json(
    file_path='/code/ICLR_2024/Auto-GUI/dataset/blip/single_blip_train_llava_10000_caption_elements_llama3_70b.json',
    data=data,
)