import os
import json
import torch
from transformers import AutoModelForCausalLM, CLIPImageProcessor


def rephrase_data_json(json_path):
    """Drop samples with an empty question or answer; keep only the first image."""
    output_json_dict = []
    with open(json_path) as f:
        json_dict = json.load(f)
    for data_dict in json_dict:
        if isinstance(data_dict['image'], list):
            data_dict['image'] = data_dict['image'][0]
        if data_dict['conversations'][0]['value'] == '' or \
                data_dict['conversations'][1]['value'] == '':
            continue
        output_json_dict.append(data_dict)
    with open(json_path.replace('.json', '_rephrase.json'), 'w') as f:
        json.dump(output_json_dict, f, indent=4)


def test_llama3():
    """Smoke-test that the LLM and the CLIP image processor both load."""
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True)
    # Bug fix: the original rebound `model` here, discarding the LLM just loaded.
    image_processor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-large-patch14-336")


def create_test_data():
    """Concatenate the first 200 samples of every *_test.json split into one file."""
    data_json = []
    for name in os.listdir('./data_json'):
        if name.endswith('_test.json'):
            with open(os.path.join('./data_json', name)) as f:
                data_json.extend(json.load(f)[:200])
    with open('./data_json/instruct_sample_18430_0713_rephrase_test.json', 'w') as f:
        json.dump(data_json, f, indent=4)


def create_eval_dataset(gt_data_dir, pred_data_json, output_dir):
    """Group predictions by task, attach the ground-truth question, and write one file per task."""
    with open(pred_data_json, 'r') as f:
        pred_data = json.load(f)  # bug fix: read from the file handle, not the path string
    gt_data_list = {}
    pred_data_list = {}
    for cur_data in pred_data:
        task = cur_data['task']
        if task not in pred_data_list:
            pred_data_list[task] = []
        cur_data['answer'] = cur_data['pred_output']
        cur_data.pop('source')
        cur_data.pop('task')
        cur_data.pop('gt_output')
        # Lazily load the ground-truth split for this task (first 200 samples,
        # matching create_test_data).
        if task not in gt_data_list:
            with open(os.path.join(gt_data_dir, f'4dor_{task}_instruct_0711_test.json'), 'r') as f:
                gt_data_list[task] = json.load(f)[:200]
        for gt in gt_data_list[task]:
            if gt['id'] == cur_data['id']:
                cur_data['question'] = gt['conversations'][0]['value']
                # The prefix literal was lost in the source; '<image>\n' is the
                # usual LLaVA image token and matches the 8-character slice below.
                if cur_data['question'].startswith('<image>\n'):
                    cur_data['question'] = cur_data['question'][8:]
        pred_data_list[task].append(cur_data)
    # Write each task's predictions once, after all items have been grouped.
    for task, items in pred_data_list.items():
        with open(os.path.join(output_dir, f'{task}_pred.json'), 'w') as f:
            json.dump(items, f, indent=4)


def create_eval_dataset_from_output(pred_data_json, output_dir):
    """Same grouping as create_eval_dataset, but without re-attaching the questions."""
    with open(pred_data_json) as f:
        pred_data = json.load(f)
    task_list = {}
    for cur_data in pred_data:
        task = cur_data['task']
        if task not in task_list:
            task_list[task] = []
        cur_data['answer'] = cur_data['pred_output']
        cur_data.pop('source')
        cur_data.pop('task')
        cur_data.pop('gt_output')
        task_list[task].append(cur_data)
    # The original built task_list but never used output_dir; writing one file
    # per task (as in create_eval_dataset) is the assumed intent.
    for task, items in task_list.items():
        with open(os.path.join(output_dir, f'{task}_pred.json'), 'w') as f:
            json.dump(items, f, indent=4)


def debug():
    """Merge the per-rank temp_{0..7}.json outputs into preds.json plus one file per task."""
    output_tasks = {}
    output_list = []
    for j in range(8):
        with open(f'./temp_{j}.json', 'r') as f:
            temp_output = json.load(f)
        for t in temp_output:
            if t['task'] not in output_tasks:
                output_tasks[t['task']] = []
            output_tasks[t['task']].append(t)
            output_list.append(t)
        os.remove(f'./temp_{j}.json')
    with open(os.path.join('.', 'preds.json'), 'w') as f:
        json.dump(output_list, f, indent=4)
    for k in output_tasks:
        # Bug fix: the original omitted the comma in os.path.join, so the two
        # adjacent string literals concatenated into '.preds_<task>.json'.
        with open(os.path.join('.', f'preds_{k}.json'), 'w') as f:
            json.dump(output_tasks[k], f, indent=4)


def count_correct(json_file='/mnt1/lyc/llava_finetune/eval_output/pwiseg_count_eval_llama3_llava.json'):
    """Count 'yes' judgements; the rate assumes 200 samples per task (see create_test_data)."""
    with open(json_file) as f:
        data = json.load(f)
    num_corr = 0
    for d in data:
        if isinstance(d, str):
            continue
        if d['answer'] == 'yes':
            num_corr += 1
    print(f'Rate: {num_corr / 200}')
    print(f'Num: {num_corr}')


def process_relationship_data_v0(relationship_json_file='/mnt1/wjl/LLaVA/data/4dor_infos_0702.json'):
    """Stub (v0): loads the relationship annotations; processing is not implemented yet."""
    with open(relationship_json_file) as f:
        relationship_json_data = json.load(f)


if __name__ == '__main__':
    # json_path = '/mnt1/lyc/LLaVA-NeXT/instruct_sample_18430_0713.json'
    # rephrase_data_json(json_path)
    # test_llama3()
    # create_test_data()
    # debug()
    count_correct()