CYF200127's picture
Upload 235 files
3e1d9f3 verified
raw
history blame
8.16 kB
import re
from .. import BaseComputeMetrics
from ..root import (
DATASETS,
METRICS,
QUESTION_PLACEHOLDER,
IMAGE_PLACEHOLDER,
BOXES_PLACEHOLDER,
POINTS_PLACEHOLDER,
)
from ..utils import MInstrDataset
# noinspection PyPep8Naming
@DATASETS.register_module()
class Point_QA_local(MInstrDataset):
def __init__(self, *args, version='p', qbp_p_prob=0.5, **kwargs):
super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
assert version in ['b', 'p', 'bp']
self.version = version
self.qbp_p_prob = qbp_p_prob
def __getitem__(self, index):
item = self.get_raw_item(index)
# image
img_path = item['file_path']
image = self.get_image(img_path)
# answer
answer = item['answer']
# question
question = item['question']
bbox = item['bbox']
point = item['point']
version = self.version
if version == 'bp':
version = 'p' if self.rng.random() < self.qbp_p_prob else 'b'
if version == 'b':
question = question + BOXES_PLACEHOLDER
query_boxes_seq = [[0]]
query_points_seq = None
elif version == 'p':
question = question + POINTS_PLACEHOLDER
query_boxes_seq = None
query_points_seq = [[0]]
else:
assert False
final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)
ret = {
'image': image,
'target': {
'boxes': [bbox],
'points': [point],
},
'conversations': [
{
'from': 'human',
'value': final_question,
'boxes_seq': query_boxes_seq,
'points_seq': query_points_seq,
},
{
'from': 'gpt',
'value': f'The answer is {answer} .',
}
]
}
return ret
# noinspection PyPep8Naming
@DATASETS.register_module()
class Point_QA_twice(MInstrDataset):
def __init__(self, *args, version='gq-p', bp_p_prob=0.5, **kwargs):
super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
self.version = version
self.bp_p_prob = bp_p_prob
qtype, rtype = version.split('-')
assert qtype in ['oq', 'sq', 'gq']
assert rtype in ['b', 'p', 'bp']
self.qtype = qtype
self.rtype = rtype
def __getitem__(self, index):
item = self.get_raw_item(index)
# image
img_path = item['file_path']
image = self.get_image(img_path)
# answer
answer = item['answer']
# question
bbox = item['bbox']
point = item['point']
if self.qtype == 'oq':
question = item['obj_question']
elif self.qtype == 'sq':
question = item['super_question']
elif self.qtype == 'gq':
question = item['general_question']
else:
assert False
rtype = self.rtype
if rtype == 'bp':
rtype = 'p' if self.rng.random() < self.bp_p_prob else 'b'
if rtype == 'p':
question = question + POINTS_PLACEHOLDER
query_boxes_seq = None
query_points_seq = [[0]]
elif rtype == 'b':
question = question + BOXES_PLACEHOLDER
query_boxes_seq = [[0]]
query_points_seq = None
else:
assert False
final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)
ret = {
'image': image,
'target': {
'boxes': [bbox],
'points': [point],
},
'conversations': [
{
'from': 'human',
'value': final_question,
'boxes_seq': query_boxes_seq,
'points_seq': query_points_seq,
},
{
'from': 'gpt',
'value': f'The answer is {answer} .',
}
]
}
return ret
# noinspection PyPep8Naming
@DATASETS.register_module()
class V7W_POINT(MInstrDataset):
def __init__(self, *args, version, do_shuffle_choice=True, **kwargs):
super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
self.version = version
self.do_shuffle_choice = do_shuffle_choice
assert version in ['p', 'b']
def __len__(self):
return len(self.data)
def __getitem__(self, index):
item = self.get_raw_item(index)
# image
img_path = item['file_path']
image = self.get_image(img_path)
# question
bboxes = item['candidates']
points = []
final_question = item['question'] + ' Candidates: ' + " ".join([BOXES_PLACEHOLDER for _ in range(len(bboxes))])
query_boxes_seq = []
for _ in range(len(bboxes)):
query_boxes_seq.append([_])
# answer
if self.version == 'p':
final_question += f" answer in point format."
points.append(item['point'])
final_answer = f"The answer is {POINTS_PLACEHOLDER} ."
answer_boxes_seq = None
answer_points_seq = [[0]]
elif self.version == 'b':
final_question += f" answer in box format."
idx = bboxes.index(item['answer'])
final_answer = f"The answer is {BOXES_PLACEHOLDER} ."
answer_boxes_seq = [[idx]]
answer_points_seq = None
else:
assert False
final_question = self.get_template().replace(QUESTION_PLACEHOLDER, final_question)
if self.do_shuffle_choice:
self.rng.shuffle(query_boxes_seq)
# bboxes, query_boxes_seq, answer_boxes_seq = self.shuffle_boxes(bboxes, query_boxes_seq, answer_boxes_seq)
ret = {
'image': image,
'target': {
'boxes': bboxes,
'points': points,
},
'conversations': [
{
'from': 'human',
'value': final_question,
'boxes_seq': query_boxes_seq,
},
{
'from': 'gpt',
'value': final_answer,
'boxes_seq': answer_boxes_seq,
'points_seq': answer_points_seq,
}
]
}
return ret
# def shuffle_boxes(self, bboxes, query_boxes_seq, answer_boxes_seq):
# idx_mapping = list(range(len(bboxes)))
# self.rng.shuffle(idx_mapping)
#
# new_bboxes = [None for _ in range(len(bboxes))]
# for idx_old, idx_new in enumerate(idx_mapping):
# new_bboxes[idx_new] = bboxes[idx_old]
#
# if query_boxes_seq is None:
# new_query_boxes_seq = None
# else:
# new_query_boxes_seq = []
# for boxes in query_boxes_seq:
# new_boxes = [idx_mapping[box_idx] for box_idx in boxes]
# new_query_boxes_seq.append(new_boxes)
#
# if answer_boxes_seq is None:
# new_answer_boxes_seq = None
# else:
# new_answer_boxes_seq = []
# for boxes in answer_boxes_seq:
# new_boxes = [idx_mapping[box_idx] for box_idx in boxes]
# new_answer_boxes_seq.append(new_boxes)
#
# return new_bboxes, new_query_boxes_seq, new_answer_boxes_seq
ANS_EXTRACT_PAT = re.compile(r'(?:The answer is (.+?)\.)')
@METRICS.register_module()
class PointQAComputeMetrics(BaseComputeMetrics):
def extract_ans(self, string: str):
try:
found = ANS_EXTRACT_PAT.findall(string.strip())
if len(found) != 1:
return None
return found[0].strip()
except (IndexError, AttributeError):
return None