import json
import os
import re

from torch.utils.data import Dataset


def prompt_processor(prompt):
    """Pull the raw question text out of an evaluation prompt and lowercase it."""
    if prompt.startswith('OCR tokens: '):
        # Format: "OCR tokens: ... Question: <question> Short answer:"
        pattern = r"Question: (.*?) Short answer:"
        match = re.search(pattern, prompt, re.DOTALL)
        question = match.group(1)
    elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
        # Three-line format: the question follows the OCR tokens when the prompt
        # starts with them, otherwise it is the first line.
        if prompt.startswith('Reference OCR token:'):
            question = prompt.split('\n')[1]
        else:
            question = prompt.split('\n')[0]
    elif len(prompt.split('\n')) == 2:
        # Two-line format: the question is the first line.
        question = prompt.split('\n')[0]
    else:
        raise ValueError(f"Unrecognized prompt format: {prompt!r}")

    return question.lower()

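# Illustrative use of prompt_processor (a minimal sketch; the prompt string below is a
# made-up example, not drawn from the dataset files):
#   prompt_processor("OCR tokens: stop 25 Question: What is the speed limit? Short answer:")
#   -> "what is the speed limit?"
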
class textVQADataset(Dataset):
    """TextVQA split: pairs each question with its image path and ground-truth answers."""

    def __init__(
        self,
        image_dir="./downloads/TextVQA/train_images",
        ann_path="./downloads/TextVQA/TextVQA_0.5.1_val.json",
    ):
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question = self.data[idx]['question']
        answers = self.data[idx]['answers']
        img_id = self.data[idx]['image_id']
        qid = self.data[idx]['question_id']
        # TextVQA images are stored as <image_id>.jpg under image_dir.
        img_path = os.path.join(self.image_dir, f"{img_id}.jpg")

        item = {
            "question_id": qid,
            "image_path": img_path,
            "question": question,
            "gt_answers": answers,
        }

        return item


class docVQADataset(Dataset):
    """DocVQA validation split, with optional OCR annotations keyed by image id."""

    def __init__(
        self,
        image_dir="./downloads/DocVQA/spdocvqa_images",
        ann_path="./downloads/DocVQA/val_v1.0_withQT.json",
        ocr_token_path=None,
    ):
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path
        if ocr_token_path:
            # Index OCR annotations by image id for direct lookup.
            with open(ocr_token_path, "r") as f:
                self.ocr_token_data = {item['image_id']: item for item in json.load(f)["data"]}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question_id = self.data[idx]['questionId']
        relative_img_path = self.data[idx]['image']
        # The annotations reference a "documents/..." prefix; the local files are
        # expected under "images/..." inside image_dir.
        corrected_relative_img_path = relative_img_path.replace("documents", "images")
        img_path = os.path.join(self.image_dir, corrected_relative_img_path)
        question = self.data[idx]['question']
        answers = self.data[idx]['answers']
        question_type = self.data[idx]['question_types']

        return {
            "question_id": question_id,
            "image_path": img_path,
            "question": question,
            "gt_answers": answers,
            "question_type": question_type,
        }


class docVQATESTDataset(Dataset):
    """DocVQA test split: the annotation file carries no answers or question types."""

    def __init__(
        self,
        image_dir="./downloads/DocVQA/spdocvqa_images",
        ann_path="./downloads/DocVQA/test_v1.0.json",
        ocr_token_path=None,
    ):
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question_id = self.data[idx]['questionId']
        relative_img_path = self.data[idx]['image']
        corrected_relative_img_path = relative_img_path.replace("documents", "images")
        img_path = os.path.join(self.image_dir, corrected_relative_img_path)
        question = self.data[idx]['question']

        return {
            "question_id": question_id,
            "image_path": img_path,
            "question": question,
            # No ground truth is available for the test split, so these stay empty.
            "gt_answers": "",
            "question_type": "",
        }
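

if __name__ == "__main__":
    # Quick smoke test (a minimal sketch, assuming the default download paths above
    # exist locally and that torch's DataLoader is available; adjust the paths to
    # your own setup before running).
    from torch.utils.data import DataLoader

    dataset = textVQADataset()
    # Keep items as plain dicts instead of letting the default collate stack them.
    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda batch: batch)
    for batch in loader:
        sample = batch[0]
        print(sample["question_id"], sample["image_path"], sample["question"])
        break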