"""Grade LLM answers against annotated answers with an LLM judge.

The script loads a predictions file and a ground-truth file, asks a chat
model whether each predicted answer is consistent with the annotated one,
and writes the verdicts (``id`` / ``answer`` / ``reason``) to a JSON file.

Configuration is read from the environment:

* ``OPENAI_API_KEY``  - required to call the judge.
* ``OPENAI_BASE_URL`` - optional; defaults to ``DEFAULT_BASE_URL``.

SECURITY NOTE: an API key used to be hard-coded in this file.  Treat that
key as leaked and revoke it; never commit credentials to source control.
"""
import argparse
import json
import os
import re
from functools import lru_cache

DEFAULT_BASE_URL = "https://vip.yi-zhan.top/v1"
JUDGE_MODEL = "gpt-4o-2024-05-13"

# Defaults used when the script is run without command-line arguments.
DEFAULT_LLM_FILE = '/mnt1/lyc/llava_finetune/results_4dor/preds_phase.json'
DEFAULT_GT_FILE = '/mnt1/lyc/llava_finetune/data_json/4dor_phase_instruct_0711_test.json'
DEFAULT_OUT_FILE = '/mnt1/lyc/llava_finetune/eval_output/4dor_phase_eval.json'

PROMPT_TEMPLATE = """
## Role
You are a fair judge, comparing the LLM answer with the annotated answer, and evaluating whether the answer is accurate about the understanding of the surgical stage.

## Question
```json
{question}
```

## LLM answer
```json
{llm_answer}
```

## Annotated answer
```json
{gt_answer}
```

## Task
For the given question, evaluate whether the LLM answer is consistent with the annotated answer. If yes, answer yes and give a reason. If no, answer no and give a reason.

## Constraints
- Your answer should be divided into two parts: "Answer" and "Reason". "Answer" should be "Yes" or "No", indicating whether the large language model's prediction is consistent with the annotation information. "Reason" should provide the reason for your answer.
- When evaluating the accuracy of the LLM's prediction, pay close attention to whether the model's answer is accurate about the understanding of the surgical phase, including whether the surgical stage is correctly identified and whether the operation suggestion given is appropriate.
- Output format is json Dictionary, as shown below:
"reason":reason,
"answer":answer

Take a deep breath and start answering step by step.
"""

# Fallback patterns for replies that are not valid JSON.  ``(?:[^"\\]|\\.)``
# lets the captured value contain escaped quotes instead of stopping at them.
_ANSWER_RE = re.compile(
    r'"answer"\s*:\s*"((?:[^"\\]|\\.)+)"', re.IGNORECASE | re.DOTALL
)
_REASON_RE = re.compile(
    r'"reason"\s*:\s*"((?:[^"\\]|\\.)+)"', re.IGNORECASE | re.DOTALL
)


class MissingAPIKeyError(RuntimeError):
    """Raised when the judge is called without ``OPENAI_API_KEY`` set."""


def load_json(filename):
    """Return the parsed content of the UTF-8 JSON file *filename*."""
    with open(filename, 'r', encoding='utf-8') as file:
        return json.load(file)


def _write_json(data, filename):
    """Write *data* to *filename* as indented JSON, creating its directory."""
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(filename, 'w', encoding='utf-8') as file:
        json.dump(data, file, indent=4)


@lru_cache(maxsize=1)
def _get_client():
    """Build the OpenAI client on first use and reuse it afterwards.

    Raises:
        MissingAPIKeyError: if ``OPENAI_API_KEY`` is not set.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise MissingAPIKeyError(
            "Set the OPENAI_API_KEY environment variable before running the judge."
        )
    # Imported here so that the prompt/parsing helpers can be imported (and
    # tested) on machines without the SDK or credentials.
    from openai import OpenAI

    return OpenAI(
        api_key=api_key,
        base_url=os.environ.get("OPENAI_BASE_URL", DEFAULT_BASE_URL),
    )


def __getattr__(name):
    """Keep the former module attribute ``client`` available, built lazily."""
    if name == "client":
        return _get_client()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def get_result(prompt):
    """Send *prompt* to the judge model and return the reply text.

    The request uses a fixed system message, no streaming and
    ``temperature=0.8``.
    NOTE(review): 0.8 makes verdicts non-deterministic between runs; a lower
    temperature may be preferable for grading - confirm before changing.
    """
    response = _get_client().chat.completions.create(
        model=JUDGE_MODEL,
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                ],
            },
        ],
        stream=False,
        temperature=0.8,
    )
    return response.choices[0].message.content


def create_prompt(question, llm_answer, gt_answer):
    """Fill the judge prompt template with the question and both answers."""
    return PROMPT_TEMPLATE.format(
        question=question, llm_answer=llm_answer, gt_answer=gt_answer
    )


def _parse_json_object(text):
    """Return the outermost ``{...}`` of *text* parsed as a dict, else None."""
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        # strict=False tolerates raw newlines inside string values.
        parsed = json.loads(text[start:end + 1], strict=False)
    except ValueError:
        return None
    return parsed if isinstance(parsed, dict) else None


def extract_answer(json_string):
    """Extract ``(answer, reason)`` from the judge's reply.

    The reply is first parsed as JSON (surrounding text such as a Markdown
    code fence is ignored); if that fails, or a field is missing, a regular
    expression is used instead.  Key matching is case-insensitive.

    Returns:
        A ``(answer, reason)`` tuple; each element is a non-empty string or
        ``None`` when it could not be found.
    """
    if not json_string:
        return None, None

    answer = reason = None
    parsed = _parse_json_object(json_string)
    if parsed is not None:
        fields = {str(key).lower(): value for key, value in parsed.items()}
        candidate = fields.get("answer")
        if isinstance(candidate, str) and candidate:
            answer = candidate
        candidate = fields.get("reason")
        if isinstance(candidate, str) and candidate:
            reason = candidate

    if answer is None:
        answer_match = _ANSWER_RE.search(json_string)
        answer = answer_match.group(1) if answer_match else None
    if reason is None:
        reason_match = _REASON_RE.search(json_string)
        reason = reason_match.group(1) if reason_match else None
    return answer, reason


def main(llm_file, gt_file, out_file):
    """Judge every prediction in *llm_file* against *gt_file*.

    Args:
        llm_file: JSON list of items with ``id``, ``question`` and ``answer``.
        gt_file: JSON list of items with ``id`` and ``conversations``; the
            annotated answer is read from ``conversations[1]["value"]``.
        out_file: Path of the JSON file receiving the verdict records.  It is
            rewritten after every judged item so partial results survive an
            interrupted run.

    Returns:
        The list of ``{"id", "answer", "reason"}`` records that were judged.

    Raises:
        MissingAPIKeyError: if the judge is needed but no API key is set.
    """
    llm_data = load_json(llm_file)
    gt_data = load_json(gt_file)

    qa_by_id = {item["id"]: {} for item in llm_data}
    for item in llm_data:
        entry = qa_by_id[item["id"]]
        entry["question"] = item["question"]
        entry["llm_answer"] = item["answer"]
    for item in gt_data:
        entry = qa_by_id.get(item["id"])
        if entry is not None:
            entry["gt_answer"] = item["conversations"][1]["value"]

    total = len(qa_by_id)
    compares = []
    correct_ans = 0
    for ix, (qid, item) in enumerate(qa_by_id.items()):
        if "gt_answer" not in item:
            # Previously this raised KeyError and aborted the whole run.
            print(f"skip: no annotated answer for id {qid!r}")
            continue

        prompt = create_prompt(
            item["question"], item["llm_answer"], item["gt_answer"]
        )
        try:
            compare = get_result(prompt=prompt)
        except MissingAPIKeyError:
            # A configuration error affects every item: fail fast.
            raise
        except Exception as exc:
            # Best effort: a failed request skips this item only.
            print("break", item, repr(exc))
            continue

        answer, reason = extract_answer(compare)
        compares.append({"id": qid, "answer": answer, "reason": reason})
        _write_json(compares, out_file)

        if answer is not None and answer.strip().lower() == 'yes':
            correct_ans = correct_ans + 1
            print(f"#correct \n answer:{answer},\n reason:{reason}")
        else:
            print(f"#wrong \n answer:{answer},\n reason:{reason}")
            if answer is None:
                print(f"unparsed judge reply: {compare!r}")
        print(f"[step {ix}, correct {correct_ans}, total {total}, rate {correct_ans/total}")

    # Make sure the output file exists even when nothing could be judged.
    _write_json(compares, out_file)
    return compares


if __name__ == "__main__":
    # Earlier runs, kept for reference (paths and the last progress line).
    #
    # InternLM-XComposer finetuned: step 199, correct 42, total 200, rate 0.21
    #   llm_file = '/mnt1/wjl/InternLM-XComposer/output/finetune_0712_pwi+4dor_epoch2/internlm-xcomposer2-vl-7b/results/4dor_phase_instruct_0711_test_results.json'
    #   gt_file = '/mnt1/wjl/InternLM-XComposer/output/GT/4dor_phase_instruct_0711_test.json'
    #   out_file = '/mnt1/wjl/InternLM-XComposer/output/finetune_0712_pwi+4dor_epoch2/internlm-xcomposer2-vl-7b/results_eval/4dor_phase_instruct_0711_test_compare.json'
    #
    # InternLM-XComposer original:
    #   llm_file = '/mnt1/wjl/InternLM-XComposer/output/finetune_0712_pwi+4dor_epoch2/internlm-xcomposer2-vl-7b/origin_results/4dor_phase_instruct_0711_test_results.json'
    #   gt_file = '/mnt1/wjl/InternLM-XComposer/output/GT/4dor_phase_instruct_0711_test.json'
    #   out_file = '/mnt1/wjl/InternLM-XComposer/output/finetune_0712_pwi+4dor_epoch2/internlm-xcomposer2-vl-7b/origin_results_eval/4dor_phase_instruct_0711_test_compare.json'
    #
    # LLaVA 7B finetuned: [step 199, correct 111, total 200, rate 0.555]
    #   llm_file = '/mnt1/wjl/LLaVA/checkpoints/llava-v1.5-7b-task-lora-2024-07-14-08/results/4dor_count_instruct_0711_test_results.json'
    #   gt_file = '/mnt1/wjl/InternLM-XComposer/output/GT/4dor_count_instruct_0711_test.json'
    #   out_file = '/mnt1/wjl/LLaVA/checkpoints/llava-v1.5-7b-task-lora-2024-07-14-08/results_eval/4dor_count_instruct_0711_test_results.json'
    #
    # LLaVA 7B original: [step 199, correct 44, total 200, rate 0.22]
    #   llm_file = '/mnt1/wjl/LLaVA/checkpoints/llava-v1.5-7b-orign-results/4dor_count_instruct_0711_test_results.json'
    #   gt_file = '/mnt1/wjl/InternLM-XComposer/output/GT/4dor_count_instruct_0711_test.json'
    #   out_file = '/mnt1/wjl/LLaVA/checkpoints/llava-v1.5-7b-orign-results-eval/4dor_count_instruct_0711_test_results.json'
    #
    # LLaVA 13B finetuned: [step 199, correct 125, total 200, rate 0.625]
    #   llm_file = '/mnt1/wjl/LLaVA/checkpoints/llava-v1.5-13b-task-lora-2024-07-14-07/results/4dor_count_instruct_0711_test_results.json'
    #   gt_file = '/mnt1/wjl/InternLM-XComposer/output/GT/4dor_count_instruct_0711_test.json'
    #   out_file = '/mnt1/wjl/LLaVA/checkpoints/llava-v1.5-13b-task-lora-2024-07-14-07/results_eval/4dor_count_instruct_0711_test_results.json'
    #
    # LLaVA 13B original:
    #   llm_file = '/mnt1/wjl/LLaVA/checkpoints/llava-v1.5-13b-origin-results/4dor_count_instruct_0711_test_results.json'
    #   gt_file = '/mnt1/wjl/InternLM-XComposer/output/GT/4dor_count_instruct_0711_test.json'
    #   out_file = '/mnt1/wjl/LLaVA/checkpoints/llava-v1.5-13b-origin-results-eval/4dor_count_instruct_0711_test_results.json'
    parser = argparse.ArgumentParser(
        description="Judge LLM answers against annotated answers."
    )
    parser.add_argument(
        "--llm-file", default=DEFAULT_LLM_FILE,
        help="path of the LLM prediction file",
    )
    parser.add_argument(
        "--gt-file", default=DEFAULT_GT_FILE,
        help="path of the ground-truth (GT) file",
    )
    parser.add_argument(
        "--out-file", default=DEFAULT_OUT_FILE,
        help="path of the evaluation output file",
    )
    args = parser.parse_args()
    compares = main(args.llm_file, args.gt_file, args.out_file)