|
from models import IntuitionKillingMachine |
|
from transforms import undo_box_transforms_batch, ToTensor, Normalize, SquarePad, Resize, NormalizeBoxCoords |
|
from torchvision.transforms import Compose |
|
from encoders import get_tokenizer |
|
from PIL import Image, ImageDraw |
|
from zipfile import ZipFile |
|
from copy import copy |
|
import pandas as pd |
|
import torch |
|
|
|
def parse_model_args(model_path):
    """Recover model hyper-parameters encoded in a checkpoint path.

    The run name is a ``_``-separated string. Its first thirteen fields are
    read positionally as::

        <ignored>_<ignored>_<dataset>_<max_length>_<input_size>_<backbone>_
        <num_heads>_<num_layers>_<num_conv>_<ignored>_<ignored>_<mu>_<mask_pooling>

    Any further fields are ignored.

    Args:
        model_path: Path (or bare run name) containing the encoded run name,
            e.g. ``"cache/<run name>/best.ckpt"``.

    Returns:
        A dict with keys ``dataset``, ``max_length``, ``input_size``,
        ``backbone``, ``num_heads``, ``num_layers``, ``num_conv``, ``mu`` and
        ``mask_pooling`` (``True`` only when the field is exactly ``'1'``).

    Raises:
        ValueError: If fewer than thirteen fields are available, or a numeric
            field cannot be converted.
    """
    num_fields = 13

    # Look for the run name in a single path component (searching from the
    # end of the path) so that underscores in parent directories do not
    # shift the positional fields.
    fields = None
    for component in reversed(model_path.replace('\\', '/').split('/')):
        candidate = component.split('_')
        if len(candidate) >= num_fields:
            fields = candidate
            break

    if fields is None:
        # No single component is long enough: keep the historical behaviour
        # of splitting the whole string.
        fields = model_path.split('_')

    if len(fields) < num_fields:
        raise ValueError(
            f"cannot parse model arguments from {model_path!r}: expected at "
            f"least {num_fields} '_'-separated fields, got {len(fields)}"
        )

    (_, _, dataset, max_length, input_size, backbone, num_heads,
     num_layers, num_conv, _, _, mu, mask_pooling) = fields[:num_fields]

    return {
        'dataset': dataset,
        'max_length': int(max_length),
        'input_size': int(input_size),
        'backbone': backbone,
        'num_heads': int(num_heads),
        'num_layers': int(num_layers),
        'num_conv': int(num_conv),
        'mu': float(mu),
        'mask_pooling': mask_pooling == '1',
    }
|
|
|
|
|
class Prober:
    """Run the model on one sample at a time and draw the predicted box.

    The constructor reads a JSON table of samples, builds an
    ``IntuitionKillingMachine`` from the hyper-parameters encoded in the
    checkpoint path, and opens the zip archive holding the images. The
    archive stays open for the lifetime of the instance; call ``close()`` or
    use the instance as a context manager to release it.
    """

    def __init__(self,
                 df_path=None,
                 dataset_path=None,
                 model_checkpoint=None):
        """Build the tokenizer, sample table, model and image transforms.

        Args:
            df_path: Path of a JSON file readable by ``pandas.read_json`` with
                at least the columns ``sample_idx``, ``bbox``, ``file_path``
                and ``sent``.
            dataset_path: Path of the zip archive containing the images.
            model_checkpoint: Checkpoint path whose name encodes the model
                hyper-parameters (see ``parse_model_args``).
        """
        params = parse_model_args(model_checkpoint)
        mean = [0.485, 0.456, 0.406]
        sdev = [0.229, 0.224, 0.225]

        self.tokenizer = get_tokenizer()

        # Copy the column subset so the assignments below work on an
        # independent frame (avoids pandas chained-assignment warnings).
        self.df = pd.read_json(df_path)[
            ['sample_idx', 'bbox', 'file_path', 'sent']
        ].copy()
        # image_id is the file name without its last four characters
        # (presumably a 3-letter extension plus the dot), as an int.
        self.df.loc[:, "image_id"] = self.df.file_path.apply(
            lambda x: int(x.split('/')[-1][:-4])
        )
        self.df.file_path = self.df.file_path.apply(
            lambda x: x.replace('refer/data/images/', '')
        )

        # NOTE(review): model_checkpoint is only used to recover
        # hyper-parameters; no weights are loaded from it here and the model
        # is not switched to eval mode — confirm this is intended.
        self.model = IntuitionKillingMachine(
            backbone=params['backbone'],
            pretrained=True,
            num_heads=params['num_heads'],
            num_layers=params['num_layers'],
            num_conv=params['num_conv'],
            segmentation_head=bool(params['mu'] > 0.0),
            mask_pooling=params['mask_pooling'],
        )

        self.transform = Compose([
            ToTensor(),
            Normalize(mean, sdev),
            SquarePad(),
            Resize(size=(params['input_size'], params['input_size'])),
            NormalizeBoxCoords(),
        ])

        # NOTE(review): the run name also encodes a max_length
        # (params['max_length']) which is not used here — confirm 30 is the
        # intended token limit.
        self.max_length = 30

        self.zipfile = ZipFile(dataset_path, 'r')

    def close(self):
        """Close the image archive opened by the constructor."""
        self.zipfile.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    @torch.no_grad()
    def probe(self, idx, re, search_by_sample_id: bool = True):
        """Predict a box for one sample and draw it on the image.

        Args:
            idx: When ``search_by_sample_id`` is true, a label of the sample
                table's index (looked up with ``.loc``); otherwise an
                ``image_id`` value, in which case the first matching row is
                used.
            re: Text handed to the tokenizer (presumably the referring
                expression); truncated to ``self.max_length`` tokens.
            search_by_sample_id: Selects the lookup mode described above.

        Returns:
            The RGB ``PIL.Image`` of the sample with the model output drawn
            as a rectangle outline.

        Raises:
            IndexError: If no row has the requested ``image_id``.
        """
        if search_by_sample_id:
            img_path, target = self.df.loc[idx][['file_path', 'bbox']].values
        else:
            matches = self.df[self.df.image_id == idx][['file_path', 'bbox']].values
            if len(matches) == 0:
                raise IndexError(f"no sample with image_id {idx!r}")
            img_path, target = matches[0]

        # convert() loads the pixel data and returns a new image, so the
        # archive member can be closed as soon as the conversion is done.
        with self.zipfile.open(img_path) as img_file:
            img = Image.open(img_file).convert('RGB')
        W0, H0 = img.size

        sample = {
            'image': img,
            'image_size': (H0, W0),
            'bbox': torch.tensor([copy(target)]),
            'bbox_raw': torch.tensor([copy(target)]),
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),
            'mask_bbox': None,
        }
        print('inn bbox: ', sample['bbox'])
        sample = self.transform(sample)

        tok = self.tokenizer(re,
                             max_length=self.max_length,
                             return_tensors='pt',
                             truncation=True)

        # Batch of one: stack each tensor along a new leading dimension.
        inn = {'image': torch.stack([sample['image']]),
               'mask': torch.stack([sample['mask']]),
               'bbox': torch.stack([sample['bbox']]),
               'tok': tok}

        # Undo the transforms on the first model output using the transform
        # parameters recorded for this sample, then take the single box.
        output = undo_box_transforms_batch(
            self.model(inn)[0],
            [sample['tr_param']],
        ).numpy().tolist()[0]

        # The box is drawn in place on ``img``; "#RRGGBBAA" notation with the
        # green channel at full intensity.
        draw = ImageDraw.Draw(img)
        draw.rectangle(output, outline="#00FF0000", width=3)
        return img
|
|
|
if __name__ == "__main__":
    # Script entry point: build a Prober from hard-coded paths and run a
    # single probe with index 0 and the text "tree".
    demo_settings = {
        'df_path': 'data/val-sim_metric.json',
        'dataset_path': "data/saiapr_tc-12.zip",
        'model_checkpoint': "cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt",
    }
    prober = Prober(**demo_settings)

    # The returned image is not used here; the call only exercises the
    # pipeline end to end.
    prober.probe(0, "tree")

    print("Done")