# llava_finetune/eval_scripts/caption_eval.py
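"""Caption evaluation for LLaVA fine-tuning runs.

Compares model-generated captions against ground-truth reference captions and
reports sample-averaged BLEU-1..4 (NLTK) together with corpus-level ROUGE-L
and CIDEr (pycocoevalcap).
"""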
import json
import itertools

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider

def load_json(filename):
    with open(filename, 'r') as file:
        return json.load(file)

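# Illustrative input records for extract_answers(), assumed from the field
# accesses below (actual files may carry extra keys):
#   prediction entry:   {"id": "0001", "answer": "model-generated caption"}
#   ground-truth entry: {"id": "0001",
#                        "conversations": [{"from": "human", "value": "..."},
#                                          {"from": "gpt", "value": "reference caption"}]}
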
def extract_answers(llm_data, gt_data):
    # Map each sample id to the model prediction, and to the first 'gpt' turn
    # of the ground-truth conversation (the reference caption).
    llm_answers = {item['id']: item['answer'] for item in llm_data}
    gt_answers = {
        item['id']: [conv['value'] for conv in item['conversations'] if conv['from'] == 'gpt'][0]
        for item in gt_data
    }
    return llm_answers, gt_answers

def compute_bleu_scores(reference, hypothesis):
    # Sentence-level BLEU-1..4 with smoothing, computed on token lists.
    smooth_fn = SmoothingFunction().method1
    weights = [
        (1, 0, 0, 0),              # BLEU-1
        (0.5, 0.5, 0, 0),          # BLEU-2
        (0.33, 0.33, 0.33, 0),     # BLEU-3
        (0.25, 0.25, 0.25, 0.25),  # BLEU-4
    ]
    scores = [sentence_bleu([reference], hypothesis, weights=w, smoothing_function=smooth_fn)
              for w in weights]
    return scores

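# Example (illustrative values only): BLEU expects token lists, e.g.
#   compute_bleu_scores("a surgeon holds a scalpel".split(),
#                       "a surgeon is holding a scalpel".split())
# returns [BLEU-1, BLEU-2, BLEU-3, BLEU-4] for that single sentence pair.
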
def compute_rouge_scores(references, hypotheses):
    # pycocoevalcap's compute_score returns (corpus-average score, per-sample
    # scores); only the average is needed here.
    rouge = Rouge()
    avg_score, _ = rouge.compute_score(references, hypotheses)
    return avg_score

def compute_cider_scores(references, hypotheses):
    cider = Cider()
    avg_score, _ = cider.compute_score(references, hypotheses)
    return avg_score

def main(llm_file, gt_file):
    llm_data = load_json(llm_file)
    gt_data = load_json(gt_file)
    llm_answers, gt_answers = extract_answers(llm_data, gt_data)

    # Per-sample BLEU-n accumulators (n = 1..4) and id-keyed caption dicts for
    # the corpus-level ROUGE-L / CIDEr scorers.
    bleu_scores = {i: [] for i in range(4)}
    references = {}
    hypotheses = {}

    for sample_id in llm_answers:
        if sample_id in gt_answers:
            hypothesis = llm_answers[sample_id].split()
            reference = gt_answers[sample_id].split()
            bleu = compute_bleu_scores(reference, hypothesis)
            for i in range(4):
                bleu_scores[i].append(bleu[i])
            references[sample_id] = [gt_answers[sample_id]]
            hypotheses[sample_id] = [llm_answers[sample_id]]
        else:
            print(f"ID {sample_id} not found in ground truth data.")

    rouge_score = compute_rouge_scores(references, hypotheses)
    cider_score = compute_cider_scores(references, hypotheses)
    avg_bleu_scores = [sum(scores) / len(scores) for scores in bleu_scores.values()]

    print(f"Average BLEU-1: {avg_bleu_scores[0]:.4f}")
    print(f"Average BLEU-2: {avg_bleu_scores[1]:.4f}")
    print(f"Average BLEU-3: {avg_bleu_scores[2]:.4f}")
    print(f"Average BLEU-4: {avg_bleu_scores[3]:.4f}")
    print(f"Average ROUGE-L: {rouge_score:.4f}")
    print(f"Average CIDEr: {cider_score:.4f}")

def test():
    # Evaluate every (ground-truth file, model variant) combination:
    # 'finetune' vs. 'ori' predictions on the 'pwiseg' and '4dor' test sets.
    llm_type = ['finetune', 'ori']
    data_gt = {
        'pwiseg': '/mnt1/wjl/InternLM-XComposer/output/GT/pwiseg_caption_instruct_0712_test.json',
        '4dor': '/mnt1/lyc/llava_finetune/data_json/4dor_caption_instruct_0711_test.json',
    }
    for dt, lt in itertools.product(data_gt, llm_type):
        print(f'[INFO] data {dt} llm {lt}')
        # Replace these paths with your own prediction / ground-truth files.
        llm_file = f'/mnt1/lyc/llava_finetune/eval_output/results_{dt}_{lt}/preds_description.json'
        gt_file = data_gt[dt]
        main(llm_file, gt_file)
        print()
        print()

if __name__ == '__main__':
    test()