import sys
import logging
import warnings
from typing import Dict, Any, Sequence
import torch
from torchvision.ops import box_iou
from ..utils import (
    MInstrDataset,
    BaseComputeMetrics,
)
from ..process_function import (
    BoxFormatter,
)
from ..root import (
    DATASETS,
    METRICS,
    IMAGE_PLACEHOLDER,
    BOXES_PLACEHOLDER,
    EXPR_PLACEHOLDER,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

@DATASETS.register_module()
class RECDataset(MInstrDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, EXPR_PLACEHOLDER))

    def __getitem__(self, index):
        item = self.get_raw_item(index)
        img_path = item['img_path']
        expr = item['expression']
        bbox = item['bbox']

        image = self.get_image(img_path)
        question = self.get_template().replace(EXPR_PLACEHOLDER, expr)

        ret = {
            'image': image,
            'target': {
                'boxes': [bbox],
            },
            'conversations': [
                {
                    'from': 'human',
                    'value': question,
                },
                {
                    'from': 'gpt',
                    'value': f'Answer: {BOXES_PLACEHOLDER} .',
                    'boxes_seq': [[0]],
                }
            ]
        }
        return ret
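
# For reference, a single item produced by RECDataset looks roughly like the
# sketch below (values are illustrative: the question text comes from the
# configured templates, the placeholder tokens are shown symbolically, and
# bbox comes from the annotation file):
#
#     {
#         'image': <PIL.Image.Image>,
#         'target': {'boxes': [[x1, y1, x2, y2]]},
#         'conversations': [
#             {'from': 'human', 'value': 'Where is "the red mug" in <image>?'},
#             {'from': 'gpt', 'value': 'Answer: <boxes> .', 'boxes_seq': [[0]]},
#         ],
#     }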

@METRICS.register_module()
class RECComputeMetrics(BaseComputeMetrics):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # reuse the target-box formatter so boxes can be parsed back out of text
        self.box_formatter: BoxFormatter = self.preprocessor['target']['boxes']

    def calculate_metric(self, preds: Sequence[str], targets: Sequence[str]) -> Dict[str, Any]:
        failed = 0
        target_failed = 0

        pred_boxes, target_boxes = [], []
        for pred, target in zip(preds, targets):
            extract_pred = self.extract_ans(pred)
            extract_target = self.extract_ans(target)
            if extract_target is None:
                target_failed += 1
                logger.warning(f"failed to extract ans for target: {target}")
                continue
            if extract_pred is None:
                failed += 1
                logger.warning(f"failed to extract ans for pred: {pred}")
                extract_pred = [0, 0, 0, 0]
            target_boxes.append(extract_target)
            pred_boxes.append(extract_pred)

        with torch.no_grad():
            target_boxes = torch.tensor(target_boxes)
            pred_boxes = torch.tensor(pred_boxes)
            # normalized box coordinates are tiny, so scale them up to keep box areas non-zero
            ious = box_iou(pred_boxes * 1000, target_boxes * 1000)
            ious = torch.einsum('i i -> i', ious)  # take the diagonal elements
            # NOTE: IoU is computed only over samples whose target was successfully parsed
            iou = ious.mean().item()
            correct = (ious > 0.5).sum().item()

        # HACK: images are currently expanded to squares, so this IoU equals the real IoU.
        warn_message = (
            "This IoU is calculated on normalized boxes and is intended only for "
            "non-rigorous training-progress checks; the value is consistent with "
            "the real IoU only if image.width == image.height."
        )
        warnings.warn(warn_message)

        return {
            'accuracy': 1.0 * correct / len(targets),
            'target_failed': target_failed,
            'failed': failed,
            'iou': iou,
            'warning': warn_message,
        }
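
    # A hedged usage sketch (the exact box notation depends on the configured
    # BoxFormatter; "(0.10,0.20,0.50,0.60)" here is an assumed plain-text
    # serialization, not necessarily the one this repo uses):
    #
    #     metrics = RECComputeMetrics(...)
    #     metrics.calculate_metric(
    #         preds=['Answer: (0.10,0.20,0.50,0.60) .'],
    #         targets=['Answer: (0.10,0.20,0.50,0.60) .'],
    #     )
    #     # -> {'accuracy': 1.0, 'target_failed': 0, 'failed': 0, 'iou': 1.0, ...}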

    def extract_ans(self, string: str):
        try:
            list_of_boxes = self.box_formatter.extract(string)
            # expect exactly one boxes_seq containing exactly one 4-value box
            if len(list_of_boxes) != 1 or len(list_of_boxes[0]) != 1:
                return None
            box = list_of_boxes[0][0]
            if len(box) != 4:
                return None
            return box
        except Exception as e:
            logger.warning(f"extract_ans failed for {string!r} with exception: {e}")
            return None
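
# Standalone sanity check of the IoU math used above (independent of the
# registry classes, so it can be run in a plain Python session with torch and
# torchvision installed):
#
#     >>> import torch
#     >>> from torchvision.ops import box_iou
#     >>> pred = torch.tensor([[0.1, 0.1, 0.5, 0.5]])
#     >>> tgt = torch.tensor([[0.1, 0.1, 0.5, 0.5]])
#     >>> box_iou(pred * 1000, tgt * 1000).diag()
#     tensor([1.])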