|
import torch

# Demo inputs: the picture to segment, the referring expression describing the
# target, and the segmentation checkpoint to load.
image_path = './demo/demo.jpg'
sentence = 'the most handsome guy'
weights = './checkpoints/refcoco.pth'

# Prefer the first GPU; fall back to the CPU so the demo still runs (slowly) on
# machines without CUDA instead of failing when tensors are moved to the device.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
from PIL import Image |
|
import torchvision.transforms as T |
|
import numpy as np |
|
# Read the picture as RGB; keep a NumPy copy of the untouched pixels (used for
# the final overlay) and its original size (used to resize the prediction back).
with Image.open(image_path) as raw_image:
    img = raw_image.convert("RGB")
img_ndarray = np.array(img)
original_w, original_h = img.size

# Resize so the shorter side is 480 px (int size keeps the aspect ratio), convert
# to a float tensor and normalise with the usual ImageNet channel statistics.
image_transforms = T.Compose([
    T.Resize(480),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Add a leading batch dimension of 1 and move the tensor to the target device.
img = image_transforms(img).unsqueeze(0).to(device)
|
|
|
|
|
from bert.tokenization_bert import BertTokenizer |
|
import torch |
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Token ids of the expression with the tokenizer's special tokens added, capped at
# 20 ids. NOTE: if the tokenizer appends a closing special token, long sentences
# lose it to this truncation.
sentence_tokenized = list(tokenizer.encode(text=sentence, add_special_tokens=True))[:20]
pad_count = 20 - len(sentence_tokenized)

# Right-pad the ids with 0 to a fixed length of 20; the mask is 1 for real tokens
# and 0 for padding. Both become (1, 20) tensors on the target device.
padded_sent_toks = torch.tensor(sentence_tokenized + [0] * pad_count).unsqueeze(0).to(device)
attention_mask = torch.tensor([1] * len(sentence_tokenized) + [0] * pad_count).unsqueeze(0).to(device)
|
|
|
|
|
from bert.modeling_bert import BertModel |
|
from lib import segmentation |
|
|
|
|
|
|
|
|
|
class args:
    """Attribute bag passed as ``args`` to the segmentation model factory.

    Stands in for a parsed command-line namespace; the values below are read
    as plain class attributes.
    """

    # Backbone size selector (name suggests a Swin variant; confirm in lib.segmentation).
    swin_type: str = 'base'
    # Presumably selects the 12x12 attention-window configuration — TODO confirm.
    window12: bool = True
    mha: str = ''
    # Dropout probability for the fusion module (0.0 = disabled), per the name — confirm.
    fusion_drop: float = 0.0
|
|
|
|
|
# Build the segmentation network (configured by the `args` option class) and
# the BERT text encoder.
single_model = segmentation.__dict__['lavt'](pretrained='', args=args)
single_model.to(device)
model_class = BertModel
single_bert_model = model_class.from_pretrained('bert-base-uncased')
# The pooler is removed; this script only reads the first element of BERT's output.
single_bert_model.pooler = None

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Tensors are mapped to CPU first and moved to `device` below.
checkpoint = torch.load(weights, map_location='cpu')
single_bert_model.load_state_dict(checkpoint['bert_model'])
single_model.load_state_dict(checkpoint['model'])

# This script only runs inference. Switch both networks to eval mode so that
# training-only behaviour (dropout, batch-norm using per-batch statistics, ...)
# is turned off. A freshly constructed module starts in training mode.
# Module.eval() returns the module itself, so `model`/`bert_model` still alias
# `single_model`/`single_bert_model`.
model = single_model.to(device).eval()
bert_model = single_bert_model.to(device).eval()
|
|
|
|
|
|
|
import torch.nn.functional as F |
|
# Pure inference: run the forward pass without autograd tracking so no graph or
# intermediate activations are kept for a backward pass that never happens.
with torch.no_grad():
    # Text features: element 0 of the BERT output is used as the token states.
    last_hidden_states = bert_model(padded_sent_toks, attention_mask=attention_mask)[0]
    # Swap the last two axes — presumably (batch, tokens, hidden) ->
    # (batch, hidden, tokens), the layout the segmentation model expects; TODO confirm.
    embedding = last_hidden_states.permute(0, 2, 1)
    output = model(img, embedding, l_mask=attention_mask.unsqueeze(-1))

    # Highest-scoring index along dim 1 (presumably the class axis), then resize
    # the label map to the original image size. F.interpolate defaults to
    # nearest-neighbour, so the labels stay integral.
    output = output.argmax(1, keepdim=True)
    output = F.interpolate(output.float(), (original_h, original_w))
    output = output.squeeze()

output = output.cpu().numpy()
|
|
|
|
|
|
|
def overlay_davis(image, mask, colors=None, cscale=1, alpha=0.4):
    """Tint every labelled region of ``image`` and outline it in black.

    Args:
        image: H x W x 3 array to draw on; it is copied, not modified.
        mask: H x W label map. Label 0 is background and left untouched;
            label ``k`` is tinted with ``colors[k]``.
        colors: sequence of RGB triples indexed by label. Defaults to
            ``[[0, 0, 0], [255, 0, 0]]`` (label 1 drawn in red).
        cscale: multiplier applied to every colour component.
        alpha: weight of the original pixel in the blend; the colour gets
            ``1 - alpha``.

    Returns:
        A copy of ``image`` with the same dtype. Each non-zero label region is
        blended with its colour, and the one-pixel outer border gained by a
        single binary dilation is painted black.

    Raises:
        IndexError: if ``mask`` contains a label with no entry in ``colors``.
    """
    # Public scipy.ndimage entry point (scipy.ndimage.morphology is deprecated).
    from scipy.ndimage import binary_dilation

    # None sentinel instead of a mutable list default.
    if colors is None:
        colors = [[0, 0, 0], [255, 0, 0]]
    colors = np.atleast_2d(np.reshape(colors, (-1, 3))) * cscale

    im_overlay = image.copy()
    object_ids = np.unique(mask)

    # Drop the background label explicitly. Slicing off the first unique value
    # skipped a real object whenever the mask contained no 0 pixels.
    for object_id in object_ids[object_ids != 0]:
        binary_mask = mask == object_id

        # Blend only the selected pixels: alpha * pixel + (1 - alpha) * colour.
        im_overlay[binary_mask] = image[binary_mask] * alpha + (1 - alpha) * colors[object_id]

        # Pixels added by one dilation step form the region's outer contour.
        contours = binary_dilation(binary_mask) ^ binary_mask
        im_overlay[contours, :] = 0

    return im_overlay.astype(image.dtype)
|
|
|
|
|
# The interpolated label map is a float array; cast it to uint8 labels so each
# value can index the overlay colour table.
output = output.astype(np.uint8)

# Draw the predicted region over the untouched original pixels, then write it out.
visualization = Image.fromarray(overlay_davis(img_ndarray, output))
visualization.save('./demo/demo_result.jpg')
|
|
|
|
|
|
|
|
|
|