|
import re |
|
|
|
from .. import BaseComputeMetrics |
|
from ..root import ( |
|
DATASETS, |
|
METRICS, |
|
QUESTION_PLACEHOLDER, |
|
IMAGE_PLACEHOLDER, |
|
BOXES_PLACEHOLDER, |
|
POINTS_PLACEHOLDER, |
|
) |
|
from ..utils import MInstrDataset |
|
|
|
|
|
|
|
@DATASETS.register_module()
class Point_QA_local(MInstrDataset):
    """Single-turn QA samples whose question ends with a box or point placeholder.

    Args:
        version: ``'b'`` appends ``BOXES_PLACEHOLDER`` to the question, ``'p'``
            appends ``POINTS_PLACEHOLDER``, and ``'bp'`` picks one of the two
            for every sample.
        qbp_p_prob: Only consulted for ``'bp'``; the point form is picked when
            ``self.rng.random()`` is below this value.
    """

    def __init__(self, *args, version='p', qbp_p_prob=0.5, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
        assert version in ('b', 'p', 'bp')
        self.version, self.qbp_p_prob = version, qbp_p_prob

    def __getitem__(self, index):
        """Return the ``image`` / ``target`` / ``conversations`` dict for raw item ``index``."""
        record = self.get_raw_item(index)
        image = self.get_image(record['file_path'])

        answer = record['answer']
        question = record['question']
        # Both annotations are always exported, whichever one the question refers to.
        target = {'boxes': [record['bbox']], 'points': [record['point']]}

        mode = self.version
        if mode == 'bp':
            # Mixed mode: one random draw decides how this sample refers to the region.
            use_point = self.rng.random() < self.qbp_p_prob
            mode = 'p' if use_point else 'b'

        # The [[0]] sequences presumably index into the target lists above -- TODO confirm
        # against the consumer of 'boxes_seq' / 'points_seq'.
        if mode == 'p':
            question = question + POINTS_PLACEHOLDER
            boxes_seq, points_seq = None, [[0]]
        elif mode == 'b':
            question = question + BOXES_PLACEHOLDER
            boxes_seq, points_seq = [[0]], None
        else:
            assert False

        template = self.get_template()
        human_turn = {
            'from': 'human',
            'value': template.replace(QUESTION_PLACEHOLDER, question),
            'boxes_seq': boxes_seq,
            'points_seq': points_seq,
        }
        gpt_turn = {
            'from': 'gpt',
            'value': f'The answer is {answer} .',
        }
        return {
            'image': image,
            'target': target,
            'conversations': [human_turn, gpt_turn],
        }
|
|
|
|
|
|
|
@DATASETS.register_module()
class Point_QA_twice(MInstrDataset):
    """Single-turn QA samples with a selectable question field and reference form.

    Args:
        version: A ``'<qtype>-<rtype>'`` string. ``qtype`` selects the raw-item
            field used as the question (``'oq'`` -> ``'obj_question'``,
            ``'sq'`` -> ``'super_question'``, ``'gq'`` -> ``'general_question'``).
            ``rtype`` selects the placeholder appended to it (``'b'`` box,
            ``'p'`` point, ``'bp'`` picked per sample).
        bp_p_prob: Only consulted for rtype ``'bp'``; the point form is picked
            when ``self.rng.random()`` is below this value.
    """

    def __init__(self, *args, version='gq-p', bp_p_prob=0.5, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
        self.version = version
        self.bp_p_prob = bp_p_prob
        # Unpacking fails with ValueError unless the string has exactly one '-'.
        qtype, rtype = version.split('-')
        assert qtype in ('oq', 'sq', 'gq')
        assert rtype in ('b', 'p', 'bp')
        self.qtype, self.rtype = qtype, rtype

    def __getitem__(self, index):
        """Return the ``image`` / ``target`` / ``conversations`` dict for raw item ``index``."""
        record = self.get_raw_item(index)
        image = self.get_image(record['file_path'])

        answer = record['answer']
        # Both annotations are always exported, whichever one the question refers to.
        target = {'boxes': [record['bbox']], 'points': [record['point']]}

        qtype = self.qtype
        if qtype == 'gq':
            question = record['general_question']
        elif qtype == 'sq':
            question = record['super_question']
        elif qtype == 'oq':
            question = record['obj_question']
        else:
            assert False

        rtype = self.rtype
        if rtype == 'bp':
            # Mixed mode: one random draw decides how this sample refers to the region.
            use_point = self.rng.random() < self.bp_p_prob
            rtype = 'p' if use_point else 'b'

        if rtype == 'b':
            question = question + BOXES_PLACEHOLDER
            boxes_seq, points_seq = [[0]], None
        elif rtype == 'p':
            question = question + POINTS_PLACEHOLDER
            boxes_seq, points_seq = None, [[0]]
        else:
            assert False

        template = self.get_template()
        human_turn = {
            'from': 'human',
            'value': template.replace(QUESTION_PLACEHOLDER, question),
            'boxes_seq': boxes_seq,
            'points_seq': points_seq,
        }
        gpt_turn = {
            'from': 'gpt',
            'value': f'The answer is {answer} .',
        }
        return {
            'image': image,
            'target': target,
            'conversations': [human_turn, gpt_turn],
        }
|
|
|
|
|
|
|
@DATASETS.register_module()
class V7W_POINT(MInstrDataset):
    """Multiple-choice samples: the question lists candidate boxes, the answer is a point or a box.

    Args:
        version: ``'p'`` answers with ``POINTS_PLACEHOLDER`` (using the item's
            ``'point'``); ``'b'`` answers with ``BOXES_PLACEHOLDER`` referring
            to the position of the item's ``'answer'`` among its ``'candidates'``.
        do_shuffle_choice: When true, the question's ``boxes_seq`` entries are
            shuffled with ``self.rng``.
    """

    def __init__(self, *args, version, do_shuffle_choice=True, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
        self.version, self.do_shuffle_choice = version, do_shuffle_choice
        assert version in ('p', 'b')

    def __len__(self):
        """Return the number of entries in ``self.data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``image`` / ``target`` / ``conversations`` dict for raw item ``index``."""
        record = self.get_raw_item(index)
        image = self.get_image(record['file_path'])

        candidates = record['candidates']
        n_candidates = len(candidates)

        # One box placeholder per candidate, each paired with its own index.
        prefix = record['question'] + ' Candidates: '
        prompt = prefix + " ".join([BOXES_PLACEHOLDER] * n_candidates)
        query_boxes_seq = [[position] for position in range(n_candidates)]

        mode = self.version
        if mode == 'p':
            prompt += " answer in point format."
            points = [record['point']]
            reply = f"The answer is {POINTS_PLACEHOLDER} ."
            answer_boxes_seq, answer_points_seq = None, [[0]]
        elif mode == 'b':
            prompt += " answer in box format."
            points = []
            # list.index raises ValueError when the answer is not among the candidates.
            answer_position = candidates.index(record['answer'])
            reply = f"The answer is {BOXES_PLACEHOLDER} ."
            answer_boxes_seq, answer_points_seq = [[answer_position]], None
        else:
            assert False

        prompt = self.get_template().replace(QUESTION_PLACEHOLDER, prompt)

        if self.do_shuffle_choice:
            # In-place shuffle of the candidate index lists; the answer's index
            # into `candidates` is left untouched.
            self.rng.shuffle(query_boxes_seq)

        human_turn = {
            'from': 'human',
            'value': prompt,
            'boxes_seq': query_boxes_seq,
        }
        gpt_turn = {
            'from': 'gpt',
            'value': reply,
            'boxes_seq': answer_boxes_seq,
            'points_seq': answer_points_seq,
        }
        return {
            'image': image,
            'target': {
                'boxes': candidates,
                'points': points,
            },
            'conversations': [human_turn, gpt_turn],
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Captures, non-greedily, the text between "The answer is " and the first
# following period; the outer non-capturing group leaves one capture group.
ANS_EXTRACT_PAT = re.compile(
    r'(?:The answer is (.+?)\.)'
)
|
|
|
|
|
@METRICS.register_module()
class PointQAComputeMetrics(BaseComputeMetrics):
    """Metric class that pulls the answer text out of "The answer is ... ." strings."""

    def extract_ans(self, string: str):
        """Return the answer captured by ``ANS_EXTRACT_PAT``, or ``None``.

        ``None`` is returned when the stripped input does not contain exactly
        one match, or when an ``AttributeError``/``IndexError`` is raised
        (for example when ``string`` has no ``strip`` method).
        """
        try:
            matches = ANS_EXTRACT_PAT.findall(string.strip())
            # Zero or several matches are treated as "no answer".
            return matches[0].strip() if len(matches) == 1 else None
        except (IndexError, AttributeError):
            return None
|
|