|
from tqdm import tqdm |
|
|
|
from PIL import Image |
|
from copy import deepcopy |
|
|
|
from typing import List, Optional, Union |
|
from torch import autocast |
|
|
|
from utils.util import EditingJsonDataset, EditingSingleImageDataset, plot_images |
|
from lr_schedule import WarmupLinearLRSchedule |
|
from torch.optim.lr_scheduler import CosineAnnealingLR |
|
from torch.utils.tensorboard import SummaryWriter |
|
from models.model import RGN |
|
from models.utils import visualize_images, read_image_from_url, draw_image_with_bbox_new, Bbox |
|
from utils.util2 import compose_text_with_templates, get_augmentations_template |
|
|
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
from torchvision import datasets, transforms |
|
from engine import * |
|
from vis import * |
|
import os, jax, cv2, pdb |
|
import numpy as np |
|
import argparse, torch, inspect |
|
import PIL, time, json, datetime |
|
import random |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.backends.cudnn as cudnn |
|
import utils.misc as misc |
|
import torchvision.transforms as T |
|
import torch.distributed as dist |
|
|
|
from tqdm import tqdm |
|
from einops import rearrange |
|
from PIL import Image |
|
from copy import deepcopy |
|
|
|
from typing import List, Optional, Union |
|
from torch import autocast |
|
|
|
from utils.util import build_dataset, plot_images |
|
from lr_schedule import WarmupLinearLRSchedule |
|
from torch.utils.tensorboard import SummaryWriter |
|
from models.model import RGN |
|
from models.utils import visualize_images, read_image_from_url, draw_image_with_bbox_new, Bbox |
|
from utils.util2 import compose_text_with_templates, get_augmentations_template |
|
from torchvision.utils import draw_bounding_boxes |
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
from torchvision import datasets, transforms |
|
from engine import * |
|
from utils.post_process import get_final_img |
|
import random |
|
import os, jax, cv2, pdb |
|
import numpy as np |
|
import argparse, torch, inspect |
|
import PIL, time, json, datetime |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.backends.cudnn as cudnn |
|
import utils.misc as misc |
|
import torchvision.transforms as T |
|
import torch.distributed as dist |
|
from tqdm import tqdm |
|
|
|
from PIL import Image |
|
from copy import deepcopy |
|
|
|
from typing import List, Optional, Union |
|
from torch import autocast |
|
from torchvision import utils as vutils |
|
from utils.util import EditingJsonDataset, EditingSingleImageDataset, plot_images |
|
from lr_schedule import WarmupLinearLRSchedule |
|
from torch.optim.lr_scheduler import CosineAnnealingLR |
|
from torch.utils.tensorboard import SummaryWriter |
|
from models.model import RGN |
|
from models.utils import visualize_images, read_image_from_url, draw_image_with_bbox_new, Bbox |
|
from utils.util2 import compose_text_with_templates, get_augmentations_template |
|
from torchvision.utils import draw_bounding_boxes |
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
from torchvision import datasets, transforms |
|
from engine import * |
|
from vis import * |
|
import os, jax, cv2, pdb |
|
import numpy as np |
|
import argparse, torch, inspect |
|
import PIL, time, json, datetime |
|
import random |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.backends.cudnn as cudnn |
|
import utils.misc as misc |
|
import torchvision.transforms as T |
|
import torch.distributed as dist |
|
|
|
|
|
def map_cooridates(bbox, min_num=0, max_num=255, src_min=0, src_max=31):
    """Linearly rescale coordinates from a source range to a target range.

    ``src_min`` maps to ``min_num`` and ``src_max`` maps to ``max_num``;
    every other value follows the same straight line (values outside the
    source range are extrapolated, not clipped).  With the defaults,
    coordinates in ``[0, 31]`` are mapped onto ``[0, 255]``.

    Args:
        bbox: Value(s) to rescale.  Any object supporting ``-``, ``*`` and
            ``+`` with Python numbers works (e.g. a float or a tensor).
        min_num: Lower bound of the target range.
        max_num: Upper bound of the target range.
        src_min: Lower bound of the source range (previously hard-coded 0).
        src_max: Upper bound of the source range (previously hard-coded 31).

    Returns:
        The rescaled value(s).

    Raises:
        ZeroDivisionError: If ``src_min == src_max`` for plain Python numbers.
    """
    # Keep the original evaluation order (scale first, then shift) so that
    # floating-point results are bit-identical for the default arguments.
    scale = (max_num - min_num) / (src_max - src_min)
    return scale * (bbox - src_min) + min_num
|
|
|
def get_mask_imgs(imgs, bboxs):
    """Build one binary box mask per bounding box.

    ``imgs`` is first repeated along dim 0 so that each image is paired with
    ``bboxs.shape[0] // imgs.shape[0]`` consecutive boxes.  For every repeated
    image a mask of the same shape is created that is 1 inside the box (on
    all channels) and 0 elsewhere.

    Args:
        imgs: Image batch; indexing below requires a 4-D tensor whose last
            two axes are the ones sliced by the box.
        bboxs: 2-D tensor with at least four columns per row.  Columns 1 and
            3 bound the second-to-last image axis and columns 0 and 2 bound
            the last one (presumably ``x0, y0, x1, y1`` -- confirm with the
            producer of the boxes).  Values are truncated to integers.

    Returns:
        ``(imgs, mask_imgs)``: the repeated images and a float mask tensor of
        the same shape.  The mask is allocated on the CPU, as before; callers
        move it to the target device themselves.
    """
    repeats = bboxs.shape[0] // imgs.shape[0]
    imgs = imgs.repeat_interleave(repeats, 0)
    mask_imgs = torch.zeros(imgs.shape, dtype=torch.uint8)

    # Convert all needed coordinates in one go instead of four scalar
    # ``.item()`` round-trips per box.
    boxes = bboxs[:imgs.shape[0], :4].int().tolist()
    for idx, box in enumerate(boxes):
        # A negative bound would be interpreted by slicing as "count from the
        # end" and silently select the wrong region; clamp to the image edge.
        # Upper bounds beyond the image are already clipped by slicing.
        x0, y0, x1, y1 = (max(coord, 0) for coord in box)
        mask_imgs[idx, :, y0:y1, x0:x1] = 1

    return imgs, mask_imgs.float()
|
|
|
def save_img(args, batch, results, bboxs, imgs, mask_imgs, editing_rompt):
    """Save the edited images of one batch and run the final post-processing.

    For every entry of ``results`` the edited image is written to disk; the
    first input image is additionally saved (resized) as ``input_image.png``.
    When ``args.draw_box`` is truthy, the input image with its bounding box
    drawn on top is saved as well.  Finally ``get_final_img`` is called on
    the saved input image and the results directory.

    Args:
        args: Run configuration; ``args.draw_box`` is read here and the whole
            object is forwarded to ``get_final_img``.
        batch: Batch identifier used as a file-name prefix.
        results: Edited images, indexed along dim 0.
        bboxs: One box per result, passed to ``draw_bounding_boxes``.
        imgs: Input images matching ``results``; scaled by 255 before drawing,
            so presumably in ``[0, 1]`` -- confirm with the caller.
        mask_imgs: Unused; kept so existing callers keep working.
        editing_rompt: Editing prompt; its spaces are replaced by ``-`` to
            form part of the output file names.

    NOTE(review): ``new_path``, ``new_path2``, ``new_path3`` and
    ``save_image`` are not defined in this function or by any explicit import
    visible in this file.  They are presumably module-level names (output
    directories / torchvision's ``save_image``) provided elsewhere -- confirm,
    otherwise this function raises ``NameError``.
    """
    transform = T.Resize(512)

    # Both values are independent of the loop index.  Computing them up front
    # also keeps ``ori_img_path`` defined when ``results`` is empty.
    img_name = '-'.join(str(editing_rompt).split(' '))
    ori_img_path = os.path.join(new_path, 'input_image.png')

    for i in range(results.shape[0]):
        if i == 0:
            save_image(transform(imgs[i]), ori_img_path)

        # The original code indexed an undefined name ``res`` here; the edited
        # image for this index is ``results[i]``.
        # NOTE(review): every other saved image goes through ``transform``;
        # check whether the results were meant to be resized as well.
        result_path = os.path.join(new_path2, f"{batch}_{img_name}anchor{i}.png")
        save_image(results[i], result_path)

        if args.draw_box:
            # Only build the overlay when it is actually going to be saved.
            img = (imgs[i] * 255.0).to(dtype=torch.uint8)
            # int64 instead of uint8: uint8 wraps around for coordinates
            # above 255, which would draw the box in the wrong place.
            bbox = bboxs[i].to(dtype=torch.int64).unsqueeze(0)
            draw_img = draw_bounding_boxes(img, bbox, width=3, colors=(255, 255, 0))
            draw_img_path = os.path.join(
                new_path3, f"{batch}_{img_name}anchor{i}_ori_draw.png"
            )
            save_image(transform((draw_img / 255.0).float()), draw_img_path)

    get_final_img(args, editing_rompt, ori_img_path, new_path2)
|
|
|
|
|
# ---------------------------------------------------------------------------
# Single-image demo: load one image, get anchor boxes from the model, build
# the box masks and generate the edited result for a fixed prompt.
#
# NOTE(review): ``args`` is not defined anywhere in the visible file; it must
# be provided (e.g. by an argument parser) before this section runs.
# ---------------------------------------------------------------------------
template = get_augmentations_template()
device_id = 'cuda:1'
model = RGN(image_size=args.image_size, device=device_id, args=args).to(device_id)

image_path = "images/1.png"
image_cv = cv2.imread(image_path)
if image_cv is None:
    # cv2.imread signals a missing/unreadable file by returning None; fail
    # with a clear message instead of a cryptic error inside cvtColor.
    raise FileNotFoundError(f"could not read image: {image_path}")
image_cv = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
# The transform pipeline below expects a PIL image; the original code used
# ``image_pil`` without ever creating it from the loaded array.
image_pil = Image.fromarray(image_cv)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
imgs = transform(image_pil)
# Add the batch dimension.  The previous ``[0].unsqueeze(0)`` picked channel 0
# of the single CHW image and produced a 3-D tensor, which ``get_mask_imgs``
# cannot index (it needs batch, channel and two spatial axes).
imgs = imgs.to(device=device_id, non_blocking=True).unsqueeze(0)

e_prompt = "Put some birds in the sky and some flowers around the trees"
e_prompt = compose_text_with_templates(e_prompt, template)

# NOTE(review): ``model`` is built directly from ``RGN`` above, so ``.module``
# only exists if RGN itself defines it (it is normally an attribute of a
# DataParallel/DistributedDataParallel wrapper) -- confirm.
# NOTE(review): ``map_cooridates`` maps onto 0..255 by default while the image
# is resized to 224x224 -- confirm the intended target range.
bboxs = torch.ceil(map_cooridates(model.module.get_anchor_box(imgs)))

# ``get_mask_imgs`` already repeats the images once per box and returns them,
# so no separate ``repeat_interleave`` is needed beforehand.
imgs, mask_imgs = get_mask_imgs(imgs, bboxs)

results = model.module.generate_result(imgs, mask_imgs.to(device_id), e_prompt)
# NOTE(review): assumes ``generate_result`` returns an object with ``.save``
# (e.g. a PIL image) -- confirm against the model implementation.
results.save('ans.png')
|
|
|
|
|
|