import itertools |
import json |
import os |
import re |
from collections import namedtuple |
import torch |
from tqdm import tqdm |
class InferenceSampler(torch.utils.data.sampler.Sampler):
    """Sampler that hands each distributed rank one contiguous slice of indices.

    The index space ``range(size)`` is cut into ``world_size`` consecutive
    shards whose lengths differ by at most one; the lower ranks receive the
    extra elements.  Rank and world size are read from ``torch.distributed``
    when the sampler is constructed, so the process group must already be
    initialised at that point.

    Args:
        size: Total number of samples in the dataset.  It is converted with
            ``int()`` and must be positive.
    """

    def __init__(self, size):
        self._size = int(size)
        # Validate and shard the converted value: the raw argument may be an
        # integral float (e.g. 10.0), which would otherwise reach ``range``.
        assert self._size > 0, "InferenceSampler needs a positive dataset size"
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(
            self._size, self._world_size, self._rank
        )

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        """Return the ``range`` of dataset indices owned by ``rank``.

        Args:
            total_size: Number of indices to distribute.
            world_size: Number of shards to create.
            rank: Index of the shard to return.

        Returns:
            A ``range`` covering the shard; it is empty when there are fewer
            indices than ranks and ``rank`` is one of the trailing ranks.
        """
        base, extra = divmod(total_size, world_size)
        # The first ``extra`` ranks each take one additional element.
        shard_sizes = [base + int(r < extra) for r in range(world_size)]
        begin = sum(shard_sizes[:rank])
        end = min(begin + sum(shard_sizes[rank:rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)
def collate_fn_vqa(batches):
    """Turn a list of per-sample dicts into six parallel lists.

    ``image_path``, ``question`` and ``gt_answers`` must be present in every
    sample.  ``ocr_tokens``, ``question_id`` and ``question_type`` are
    optional; a sample that lacks one contributes ``None`` at its position.

    Returns:
        A tuple ``(image_paths, questions, gt_answers, ocr_tokens,
        question_ids, question_type)`` of lists, each as long as ``batches``.
    """
    def required(key):
        return [sample[key] for sample in batches]

    def optional(key):
        return [sample[key] if key in sample else None for sample in batches]

    return (
        required('image_path'),
        required('question'),
        required('gt_answers'),
        optional('ocr_tokens'),
        optional('question_id'),
        optional('question_type'),
    )
def has_word(sentence, word):
    """Report whether ``word`` occurs in ``sentence`` as a stand-alone token.

    ``word`` is matched literally.  A regex word-boundary anchor is placed on
    each side of it only when the character on that side of ``word`` is
    alphanumeric, so punctuation-delimited tokens such as ``"c++"`` can still
    be found.  ``word`` must be non-empty.
    """
    leading = r"\b" if word[0].isalnum() else r""
    trailing = r"\b" if word[-1].isalnum() else r""
    found = re.search(leading + re.escape(word) + trailing, sentence)
    return found is not None
def remove_special_chars(s):
    """Return ``s`` with everything except ASCII letters, digits and whitespace removed."""
    return re.sub(r"[^a-zA-Z0-9\s]", "", s)
def levenshtein_distance(s1, s2):
    """Return the edit distance between two sequences.

    The distance counts single-element insertions, deletions and
    substitutions.  Only one row of the dynamic-programming table is kept at
    a time, and the shorter input is used for the row so that the row is as
    small as possible.
    """
    shorter, longer = (s2, s1) if len(s1) > len(s2) else (s1, s2)

    previous_row = range(len(shorter) + 1)
    for row_number, long_item in enumerate(longer, start=1):
        current_row = [row_number]
        for column, short_item in enumerate(shorter):
            if short_item == long_item:
                # Matching elements carry the diagonal cost over unchanged.
                cost = previous_row[column]
            else:
                cost = 1 + min(
                    previous_row[column],
                    previous_row[column + 1],
                    current_row[-1],
                )
            current_row.append(cost)
        previous_row = current_row
    return previous_row[-1]
class VQAEval:
    """Answer normalisation and scoring helpers for VQA-style predictions.

    Two scores are provided: ``evaluate_vqa_human`` (exact match against a
    set of reference answers after normalisation) and ``evaluate_anls``
    (normalised Levenshtein similarity with a cut-off threshold).  The
    normalisation steps are exposed as ``clean_text``, ``processPunctuation``
    and ``processDigitArticle``.
    """

    def __init__(self):
        # Apostrophe-less spellings mapped to their contracted form.  The
        # table is kept verbatim because any change here changes scores.
        # NOTE(review): the "somebody'd" entry looks reversed relative to its
        # neighbours, and "let's"/"she's" map to themselves - confirm before
        # touching them.
        self.contractions = {
            "aint": "ain't",
            "arent": "aren't",
            "cant": "can't",
            "couldve": "could've",
            "couldnt": "couldn't",
            "couldn'tve": "couldn't've",
            "couldnt've": "couldn't've",
            "didnt": "didn't",
            "doesnt": "doesn't",
            "dont": "don't",
            "hadnt": "hadn't",
            "hadnt've": "hadn't've",
            "hadn'tve": "hadn't've",
            "hasnt": "hasn't",
            "havent": "haven't",
            "hed": "he'd",
            "hed've": "he'd've",
            "he'dve": "he'd've",
            "hes": "he's",
            "howd": "how'd",
            "howll": "how'll",
            "hows": "how's",
            "Id've": "I'd've",
            "I'dve": "I'd've",
            "Im": "I'm",
            "Ive": "I've",
            "isnt": "isn't",
            "itd": "it'd",
            "itd've": "it'd've",
            "it'dve": "it'd've",
            "itll": "it'll",
            "let's": "let's",
            "maam": "ma'am",
            "mightnt": "mightn't",
            "mightnt've": "mightn't've",
            "mightn'tve": "mightn't've",
            "mightve": "might've",
            "mustnt": "mustn't",
            "mustve": "must've",
            "neednt": "needn't",
            "notve": "not've",
            "oclock": "o'clock",
            "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at",
            "'ow'sat": "'ow's'at",
            "shant": "shan't",
            "shed've": "she'd've",
            "she'dve": "she'd've",
            "she's": "she's",
            "shouldve": "should've",
            "shouldnt": "shouldn't",
            "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            "somebody'd": "somebodyd",
            "somebodyd've": "somebody'd've",
            "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll",
            "somebodys": "somebody's",
            "someoned": "someone'd",
            "someoned've": "someone'd've",
            "someone'dve": "someone'd've",
            "someonell": "someone'll",
            "someones": "someone's",
            "somethingd": "something'd",
            "somethingd've": "something'd've",
            "something'dve": "something'd've",
            "somethingll": "something'll",
            "thats": "that's",
            "thered": "there'd",
            "thered've": "there'd've",
            "there'dve": "there'd've",
            "therere": "there're",
            "theres": "there's",
            "theyd": "they'd",
            "theyd've": "they'd've",
            "they'dve": "they'd've",
            "theyll": "they'll",
            "theyre": "they're",
            "theyve": "they've",
            "twas": "'twas",
            "wasnt": "wasn't",
            "wed've": "we'd've",
            "we'dve": "we'd've",
            "weve": "we've",
            "werent": "weren't",
            "whatll": "what'll",
            "whatre": "what're",
            "whats": "what's",
            "whatve": "what've",
            "whens": "when's",
            "whered": "where'd",
            "wheres": "where's",
            "whereve": "where've",
            "whod": "who'd",
            "whod've": "who'd've",
            "who'dve": "who'd've",
            "wholl": "who'll",
            "whos": "who's",
            "whove": "who've",
            "whyll": "why'll",
            "whyre": "why're",
            "whys": "why's",
            "wont": "won't",
            "wouldve": "would've",
            "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've",
            "yall": "y'all",
            "yall'll": "y'all'll",
            "y'allll": "y'all'll",
            "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've",
            "y'all'dve": "y'all'd've",
            "youd": "you'd",
            "youd've": "you'd've",
            "you'dve": "you'd've",
            "youll": "you'll",
            "youre": "you're",
            "youve": "you've",
        }
        # Number words (and "none") rewritten as digits.
        self.manualMap = {
            "none": "0",
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
        }
        # Words dropped entirely during normalisation.
        self.articles = ["a", "an", "the"]
        # Raw strings keep the regex escapes out of Python's own escape
        # handling; the compiled patterns are unchanged.
        # NOTE(review): "(?!<=\d)" is a negative look-ahead for the literal
        # text "<=" plus a digit, which looks like a typo for the look-behind
        # "(?<!\d)".  As written, every "." not followed by a digit matches.
        # Left as is because changing it would change scores.
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        # A comma sitting directly between two digits, e.g. "1,000".
        self.commaStrip = re.compile(r"(\d)(\,)(\d)")
        # Punctuation handled by processPunctuation ("." is handled
        # separately through periodStrip).
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def clean_text(self, text):
        """Return ``text`` after the full normalisation pipeline.

        Newlines and tabs become spaces, surrounding whitespace is trimmed,
        then ``processPunctuation`` and ``processDigitArticle`` are applied in
        that order.
        """
        text = text.replace("\n", " ").replace("\t", " ").strip()
        text = self.processPunctuation(text)
        return self.processDigitArticle(text)

    def evaluate_vqa_human(self, answer, gt_answers):
        """Score ``answer`` against several reference answers by exact match.

        Docstring of the original implementation: TextVQA, VQAv2, OKVQA, vizwiz.

        The answer and every reference go through ``clean_text`` so that both
        sides are normalised identically.  For each reference in turn, the
        answer is compared with the *other* references and earns
        ``min(1, matches / 3)``; the result is the mean of those values.

        Args:
            answer: Predicted answer string.
            gt_answers: Iterable of reference answer strings.

        Returns:
            The mean score in ``[0, 1]``, or ``0`` when there are no references.
        """
        answer = self.clean_text(answer)
        references = [self.clean_text(reference) for reference in gt_answers]
        if not references:
            return 0

        hits = [reference == answer for reference in references]
        total_hits = sum(hits)
        # Leaving reference i out removes its own hit (if any) from the total.
        scores = [min(1, float(total_hits - hit) / 3) for hit in hits]
        return float(sum(scores)) / len(scores)

    def evaluate_anls(self, answer, gt_answers, threshold=0.5):
        """Score ``answer`` by normalised Levenshtein similarity.

        Docstring of the original implementation: DOcVQA, InfographicsVQA, STVQA.

        Both sides are lower-cased and have runs of whitespace collapsed to a
        single space.  The similarity to a reference is ``1 - distance /
        max(len(answer), len(reference))``; the best similarity over all
        references is returned, or ``0`` if it is below ``threshold``.

        Args:
            answer: Predicted answer string.
            gt_answers: A reference string or a list of reference strings.
            threshold: Similarities below this value are reported as ``0``.

        Returns:
            The best similarity, or ``0`` when it is below ``threshold`` or
            when ``gt_answers`` is an empty list.
        """
        answer = ' '.join(answer.strip().lower().split())
        if not isinstance(gt_answers, list):
            gt_answers = [gt_answers]
        if not gt_answers:
            # Nothing to compare against; mirrors evaluate_vqa_human.
            return 0

        normalized_distances = []
        for reference in gt_answers:
            reference = ' '.join(reference.strip().lower().split())
            longest = max(len(answer), len(reference))
            if longest == 0:
                # Two empty strings are treated as identical.
                normalized_distances.append(0.0)
            else:
                distance = levenshtein_distance(answer, reference)
                normalized_distances.append(float(distance) / float(longest))

        score = 1 - min(normalized_distances)
        return 0 if score < threshold else score

    def processPunctuation(self, inText):
        """Strip or space-out punctuation in ``inText``.

        Each character in ``self.punct`` is deleted when it touches a space
        somewhere in the original text, or when the text contains a comma
        between two digits; otherwise it is replaced by a space.  Those
        decisions are always made against the unmodified ``inText``.  Finally
        every period matched by ``self.periodStrip`` is removed.
        """
        # Independent of the punctuation mark, so evaluate it once.
        has_digit_comma = self.commaStrip.search(inText) is not None

        outText = inText
        for mark in self.punct:
            touches_space = mark + " " in inText or " " + mark in inText
            if touches_space or has_digit_comma:
                outText = outText.replace(mark, "")
            else:
                outText = outText.replace(mark, " ")
        # No positional third argument here: Pattern.sub's third parameter is
        # ``count``, and passing re.UNICODE there capped the substitutions at
        # 32 instead of setting a flag.
        return self.periodStrip.sub("", outText)

    def processDigitArticle(self, inText):
        """Lower-case ``inText``, map number words, drop articles, fix contractions.

        The text is split on whitespace; each word is looked up in
        ``self.manualMap``, discarded if it is in ``self.articles``, then
        looked up in ``self.contractions``.  The surviving words are joined
        with single spaces.
        """
        words = []
        for word in inText.lower().split():
            # Read-only lookup so the shared table is not grown by every word
            # that is ever normalised.
            word = self.manualMap.get(word, word)
            if word in self.articles:
                continue
            words.append(self.contractions.get(word, word))
        return " ".join(words)
def evaluate_dataset(dataset_name, answer_file_path, model_name, method = None):
    """Score a predictions file and write a copy annotated with per-item accuracy.

    The file at ``answer_file_path`` is read as JSON and iterated; every item
    must provide ``'answer'`` and ``'gt_answers'``.  Each item receives an
    ``'accuracy'`` entry, and the annotated list is written next to the input
    as ``<stem>_<model_name>_<method><ext>``.

    Args:
        dataset_name: Selects the metric: ``"textVQA"`` uses
            ``VQAEval.evaluate_vqa_human``, ``"docVQA"`` uses
            ``VQAEval.evaluate_anls``, anything else falls back to
            ``VQAEval.evaluate_has``.
        answer_file_path: Path of the JSON predictions file.
        model_name: Only used to build the output file name.
        method: Only used to build the output file name.

    Returns:
        The mean accuracy over all items, or ``0.0`` when the file holds no
        predictions.

    Raises:
        AttributeError: For any other ``dataset_name`` when ``VQAEval`` does
            not provide ``evaluate_has``.
    """
    with open(answer_file_path, 'r', encoding='utf-8') as f:
        predictions = json.load(f)

    evaluator = VQAEval()
    # Pick the metric once, so an unsupported dataset fails before any item
    # is processed.
    if dataset_name in ["textVQA"]:
        print("evaluating vqa...")
        score_fn = evaluator.evaluate_vqa_human
    elif dataset_name in ['docVQA']:
        print("evaluating anls...")
        score_fn = evaluator.evaluate_anls
    else:
        # NOTE(review): VQAEval as defined in this file has no
        # ``evaluate_has`` method, so this branch raises AttributeError
        # unless that method is supplied elsewhere - confirm the intended
        # fallback metric.
        score_fn = evaluator.evaluate_has

    total_accuracy = 0
    for item in predictions:
        accuracy = score_fn(item['answer'], item['gt_answers'])
        item['accuracy'] = accuracy
        total_accuracy += accuracy

    # An empty predictions file yields 0.0 rather than a ZeroDivisionError.
    average_accuracy = total_accuracy / len(predictions) if predictions else 0.0
    print(f'{dataset_name}:{average_accuracy}')

    # Insert the suffix before the extension only; a plain str.replace would
    # also rewrite ".json" inside directory names.
    stem, extension = os.path.splitext(answer_file_path)
    answer_model_method_path = f'{stem}_{model_name}_{method}{extension}'
    with open(answer_model_method_path, "w", encoding='utf-8') as f:
        json.dump(predictions, f, indent=4, ensure_ascii=False)

    return average_accuracy
def evaluate_VQA(
    model,
    dataset,
    model_name,
    dataset_name,
    time,
    batch_size=1,
    generate_method="interleave",
    answer_path='./answers',
):
    """Run ``model`` over ``dataset``, save its answers and score them.

    Works with or without an initialised ``torch.distributed`` process group.
    When one is initialised, each rank processes the shard chosen by
    ``InferenceSampler``, the per-rank predictions are gathered on every
    rank, and only rank 0 writes the answers file and computes the score.

    Args:
        model: Object providing ``generate`` and, for ``model_name ==
            "minicpm"`` with the ``"interleave"`` method,
            ``generate_with_interleaved``.
        dataset: Dataset whose items are collated by ``collate_fn_vqa``.
        model_name: Selects the generation call and names the output folder.
        dataset_name: Forwarded to the model, used as the answers file name
            and used to select the metric.
        time: Name of the sub-directory created under
            ``answer_path/model_name`` for this run.
        batch_size: Batch size passed to the ``DataLoader``.
        generate_method: ``"old"`` or ``"interleave"``; only consulted when
            ``model_name == "minicpm"``.
        answer_path: Root directory for answer files.

    Returns:
        ``None`` on ranks other than 0 of a distributed run, ``-1.0`` for
        ``dataset_name == "docVQATest"`` (no scoring is performed), otherwise
        the value returned by ``evaluate_dataset``.

    Raises:
        ValueError: If ``model_name == "minicpm"`` and ``generate_method`` is
            not one of the supported values.
    """
    print(f"answer path:{answer_path}")

    # Every distributed call below is guarded by this flag so that plain
    # single-process runs never touch an uninitialised process group.
    distributed = torch.distributed.is_initialized()

    sampler = InferenceSampler(len(dataset)) if distributed else None
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=collate_fn_vqa
    )

    answer_dir = os.path.join(answer_path, model_name, time)
    os.makedirs(answer_dir, exist_ok=True)

    predictions = []
    for batch in tqdm(dataloader, desc="Running inference"):
        # The OCR tokens are part of the collated batch but are not used here.
        image_paths, questions, gt_answers, _ocr_tokens, question_ids, question_type = batch

        with torch.no_grad():
            if model_name == "minicpm":
                if generate_method == "old":
                    outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name)
                elif generate_method == "interleave":
                    outputs = model.generate_with_interleaved(images=image_paths, questions=questions, datasetname=dataset_name)
                else:
                    raise ValueError(f"Wrong generate paradigm {generate_method}!")
            elif model_name == "codellama":
                # NOTE(review): called without any inputs - confirm this is
                # what the codellama wrapper expects.
                outputs = model.generate()
            else:
                outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name)

        for i, output in enumerate(outputs):
            predictions.append({
                'question_id': question_ids[i],
                'question': questions[i],
                'answer': output,
                'gt_answers': gt_answers[i],
                'image_path': image_paths[i],
                'model_name': model_name,
                'question_type': question_type[i]
            })

    if distributed:
        torch.distributed.barrier()
        world_size = torch.distributed.get_world_size()
        merged_predictions = [None for _ in range(world_size)]
        torch.distributed.all_gather_object(merged_predictions, predictions)
        predictions = list(itertools.chain.from_iterable(merged_predictions))
        # Only rank 0 writes the answers file and computes the score.
        if torch.distributed.get_rank() != 0:
            return None

    answer_file_path = os.path.join(answer_dir, f"{dataset_name}.json")
    print(f"answer_file_path:{answer_file_path}")
    with open(answer_file_path, "w", encoding='utf-8') as f:
        json.dump(predictions, f, indent=4, ensure_ascii=False)

    if dataset_name in ["docVQATest"]:
        return -1.0

    return evaluate_dataset(answer_file_path=answer_file_path, dataset_name=dataset_name, model_name=model_name)