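"""Judge LLM surgical-phase answers against annotated ground truth.

Loads model predictions and ground-truth annotations, asks a GPT-4o-based
judge whether each prediction is consistent with its annotation, extracts the
judge's yes/no verdict and reason, prints a running accuracy, and writes the
per-question verdicts to a JSON file.
"""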
import json
import re

from openai import OpenAI


def load_json(filename):
    with open(filename, 'r') as file:
        return json.load(file)


# Client for the OpenAI-compatible chat endpoint used as the judge.
client = OpenAI(api_key="sk-HZLqWTFgQKHUM0YN9d800981DbC34aEa90632493B9310360",
                base_url="https://vip.yi-zhan.top/v1")


def get_result(prompt):
    # Send the judge prompt to the chat model and return its raw reply text.
    response = client.chat.completions.create(
        model="gpt-4o-2024-05-13",
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": [
                {"type": "text", "text": prompt},
            ]},
        ],
        stream=False,
        temperature=0.8,
    )
    return response.choices[0].message.content


def create_prompt(question, llm_answer, gt_answer):
    # Double braces ({{ }}) render as literal braces after str.format().
    template = """
## Role
You are a fair judge. Compare the LLM answer with the annotated answer and evaluate whether the answer reflects an accurate understanding of the surgical phase.

## Question
```json
{question}
```

## LLM answer
```json
{llm_answer}
```

## Annotated answer
```json
{gt_answer}
```

## Task
For the given question, evaluate whether the LLM answer is consistent with the annotated answer. If yes, answer "Yes" and give a reason. If no, answer "No" and give a reason.

## Constraints
- Your answer should have two parts: "answer" and "reason". "answer" should be "Yes" or "No", indicating whether the large language model's prediction is consistent with the annotated answer. "reason" should explain your judgement.
- When evaluating the accuracy of the LLM's prediction, pay close attention to whether the model's answer reflects a correct understanding of the surgical phase, including whether the phase is correctly identified and whether the operation suggestion given is appropriate.
- The output format is a JSON dictionary, as shown below:
{{
    "reason": reason,
    "answer": answer
}}

Take a deep breath and start answering step by step.
"""

    prompt = template.format(question=question,
                             llm_answer=llm_answer,
                             gt_answer=gt_answer)
    return prompt


def extract_answer(json_string):
    # The judge reply may not be strictly valid JSON, so pull the two fields out with regexes.
    answer_match = re.search(r'"answer":\s*"([^"]+)"', json_string)
    reason_match = re.search(r'"reason":\s*"([^"]+)"', json_string, re.DOTALL)

    answer = answer_match.group(1) if answer_match else None
    reason = reason_match.group(1) if reason_match else None

    return answer, reason
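
# Example (hypothetical judge reply, for illustration only):
#   '{"reason": "The predicted phase matches the annotation.", "answer": "Yes"}'
# extract_answer would return ("Yes", "The predicted phase matches the annotation.").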


def main(llm_file, gt_file, out_file):
    llm_data = load_json(llm_file)
    gt_data = load_json(gt_file)

    # Collect question, prediction, and annotation per question id.
    QA_dict = {item["id"]: {} for item in llm_data}

    for item in llm_data:
        qid = item["id"]
        QA_dict[qid]["question"] = item["question"]
        QA_dict[qid]["llm_answer"] = item["answer"]

    for item in gt_data:
        qid = item["id"]
        if qid in QA_dict:
            QA_dict[qid]["gt_answer"] = item["conversations"][1]["value"]

    compares = []
    correct_ans = 0
    for ix, (qid, item) in enumerate(QA_dict.items()):
        question = item["question"]
        llm_answer = item["llm_answer"]
        gt_answer = item["gt_answer"]
        prompt = create_prompt(question, llm_answer, gt_answer)

        try:
            compare = get_result(prompt=prompt)

            answer, reason = extract_answer(compare)
            compare_data = {"id": qid, "answer": answer, "reason": reason}
            compares.append(compare_data)
            # Rewrite the output file after every question so progress survives a crash.
            with open(out_file, 'w') as f:
                json.dump(compares, f, indent=4)

            if answer and answer.lower() == 'yes':
                correct_ans += 1
                print(f"#correct \n answer:{answer},\n reason:{reason}")
            else:
                print(f"#wrong \n answer:{answer},\n reason:{reason}")

        except Exception as e:
            print("error, skipping", qid, e)
            continue
        print(f"[step {ix}, correct {correct_ans}, total {len(QA_dict)}, rate {correct_ans/len(QA_dict)}]")

    return compares


if __name__ == "__main__":
    llm_file = '/mnt1/lyc/llava_finetune/results_4dor/preds_phase.json'
    gt_file = '/mnt1/lyc/llava_finetune/data_json/4dor_phase_instruct_0711_test.json'
    out_file = '/mnt1/lyc/llava_finetune/eval_output/4dor_phase_eval.json'
    compares = main(llm_file, gt_file, out_file)