"""Single-image inference script for the RGN editing model.

``main`` loads one image, asks the model for anchor boxes, turns each box into
a rectangular mask and generates an edited result for a text prompt.
``save_img`` writes per-anchor results to disk (it is not called by ``main``).
"""
# NOTE: the original file repeated this import block three times; duplicates
# are collapsed here and every distinct imported name is still bound.
from tqdm import tqdm
from einops import rearrange
from PIL import Image
from copy import deepcopy
from typing import List, Optional, Union
from torch import autocast
from torchvision import utils as vutils
from utils.util import build_dataset, EditingJsonDataset, EditingSingleImageDataset, plot_images
from lr_schedule import WarmupLinearLRSchedule
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from models.model import RGN
from models.utils import visualize_images, read_image_from_url, draw_image_with_bbox_new, Bbox
from utils.util2 import compose_text_with_templates, get_augmentations_template
from utils.post_process import get_final_img
from torchvision.utils import draw_bounding_boxes, save_image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import datasets, transforms
# Star imports stay after the explicit from-imports (as in the original) so
# any names they re-export keep taking precedence.
from engine import *
from vis import *
# Plain module imports come last (original order) so these bindings win over
# anything the star imports may re-export.
import os
import jax
import cv2
import pdb
import numpy as np
import argparse
import torch
import inspect
import PIL
import time
import json
import datetime
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import utils.misc as misc
import torchvision.transforms as T
import torch.distributed as dist


def map_cooridates(bbox, min_num=0, max_num=255, feat_size=32):
    """Linearly rescale box coordinates from feature-map cells to pixels.

    Values in ``[0, feat_size - 1]`` are mapped onto ``[min_num, max_num]``;
    values outside that range are extrapolated by the same linear map.
    The misspelt name is kept so existing callers keep working.

    Args:
        bbox: coordinates in feature-map units (tensor, array or number).
        min_num: output value that feature coordinate 0 maps to.
        max_num: output value that feature coordinate ``feat_size - 1`` maps to.
        feat_size: side length of the feature map the boxes live on; defaults
            to 32, the value previously hard-coded.

    Returns:
        ``bbox`` rescaled into the output range.

    Raises:
        ValueError: if ``feat_size < 2`` (the source range would be empty).
    """
    if feat_size < 2:
        raise ValueError(f"feat_size must be >= 2, got {feat_size}")
    src_min, src_max = 0, feat_size - 1
    scale = (max_num - min_num) / (src_max - src_min)
    return scale * (bbox - src_min) + min_num


def get_mask_imgs(imgs, bboxs):
    """Build one rectangular binary mask per bounding box.

    Args:
        imgs: batch of images indexed along dim 0; the mask indexing below
            requires 4 dims (presumably (B, C, H, W) — TODO confirm).
        bboxs: (N, 4) boxes as (x1, y1, x2, y2) pixel coordinates, where N is
            a multiple of the image batch size B.

    Returns:
        ``(imgs, mask)``. ``imgs`` is each image repeated N // B times so it
        lines up with ``bboxs``. ``mask`` is a float32 tensor of the same
        shape, on the same device, that is 1 inside the box and 0 elsewhere.

    Raises:
        ValueError: if ``imgs`` is empty or N is not a multiple of B (the
            original silently dropped the surplus boxes).
    """
    n_boxes, n_imgs = bboxs.shape[0], imgs.shape[0]
    if n_imgs == 0 or n_boxes % n_imgs:
        raise ValueError(
            f"number of boxes ({n_boxes}) must be a positive multiple of "
            f"the number of images ({n_imgs})")
    imgs = imgs.repeat_interleave(n_boxes // n_imgs, 0)
    mask = torch.zeros(imgs.shape, dtype=torch.float32, device=imgs.device)
    for k, box in enumerate(bboxs):
        # int() truncates toward zero like the original `.int()`; clamping at
        # 0 stops a negative coordinate from wrapping around as a slice index.
        x1, y1, x2, y2 = (max(int(c), 0) for c in box[:4].tolist())
        mask[k, :, y1:y2, x1:x2] = 1.0
    return imgs, mask


def _resolve_dir(explicit, global_name):
    """Return *explicit* if given, else the legacy module-level *global_name*."""
    if explicit is not None:
        return explicit
    try:
        return globals()[global_name]
    except KeyError:
        raise ValueError(
            f"save_img: no output directory given and no module-level "
            f"'{global_name}' defined") from None


def save_img(args, batch, results, bboxs, imgs, mask_imgs, editing_rompt,
             out_dir=None, anchor_dir=None, draw_dir=None):
    """Save the input image, every per-anchor result and optional box overlays.

    Args:
        args: config object; ``args.draw_box`` toggles the box overlays.
        batch: batch identifier used as a filename prefix.
        results: per-anchor result images, indexed along dim 0.
        bboxs: per-anchor (x1, y1, x2, y2) boxes in pixel coordinates.
        imgs: input images aligned with ``results``; assumed to be in [0, 1]
            (values are clamped before the uint8 conversion).
        mask_imgs: unused; kept for call compatibility.
        editing_rompt: edit prompt; its words, joined by '-', go in filenames.
        out_dir: directory for ``input_image.png``. Defaults to the legacy
            module global ``new_path``.
        anchor_dir: directory for the per-anchor results. Defaults to the
            legacy global ``new_path2``.
        draw_dir: directory for box overlays. Defaults to the legacy global
            ``new_path3``; only used when ``args.draw_box`` is set.

    Raises:
        ValueError: if a needed directory is neither passed nor defined globally.
    """
    n_results = results.shape[0]
    if n_results == 0:
        # Nothing to save; the original crashed with UnboundLocalError here.
        return
    out_dir = _resolve_dir(out_dir, 'new_path')
    anchor_dir = _resolve_dir(anchor_dir, 'new_path2')
    draw_dir = _resolve_dir(draw_dir, 'new_path3') if args.draw_box else None
    for d in (out_dir, anchor_dir, draw_dir):
        if d is not None:
            os.makedirs(d, exist_ok=True)

    resize = T.Resize(512)
    img_name = '-'.join(str(editing_rompt).split(' '))
    ori_img_path = os.path.join(out_dir, 'input_image.png')

    for i in range(n_results):
        if i == 0:
            save_image(resize(imgs[i]), ori_img_path)
        save_image(results[i],
                   os.path.join(anchor_dir, f'{batch}_{img_name}anchor{i}.png'))
        if args.draw_box:
            img_u8 = (imgs[i].clamp(0, 1) * 255.0).to(dtype=torch.uint8)
            # int64, not uint8: uint8 wraps coordinates above 255.
            box = bboxs[i].to(dtype=torch.int64).unsqueeze(0)
            drawn = draw_bounding_boxes(img_u8, box, width=3, colors=(255, 255, 0))
            draw_img_path = os.path.join(
                draw_dir, f'{batch}_{img_name}anchor{i}_ori_draw.png')
            save_image(resize((drawn / 255.0).float()), draw_img_path)
    get_final_img(args, editing_rompt, ori_img_path, anchor_dir)


def main(args, device_id='cuda:1', image_path='images/1.png',
         prompt="Put some birds in the sky and some flowers around the trees",
         out_path='ans.png'):
    """Run one anchor-box + generation pass of RGN on a single image.

    Args:
        args: config object passed to ``RGN``; ``args.image_size`` is read here.
        device_id: torch device to run on.
        image_path: input image file, read with OpenCV.
        prompt: edit instruction, expanded with the augmentation templates.
        out_path: where the generated result is saved.

    Raises:
        FileNotFoundError: if ``image_path`` cannot be read.

    NOTE(review): no checkpoint is loaded here, so the model uses whatever
    weights RGN's constructor provides — confirm this is intended.
    """
    template = get_augmentations_template()
    model = RGN(image_size=args.image_size, device=device_id, args=args).to(device_id)
    model.eval()
    # The original called model.module.*, which only exists when the model is
    # wrapped in DataParallel/DDP; unwrap only if such a wrapper is present.
    net = getattr(model, 'module', model)

    image_cv = cv2.imread(image_path)  # BGR, or None if the file is unreadable
    if image_cv is None:
        raise FileNotFoundError(f"could not read image: {image_path}")
    image_pil = Image.fromarray(cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB))

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])
    # Add a batch dim: (3, H, W) -> (1, 3, H, W). The original `[0]` selected
    # the first colour channel instead.
    imgs = preprocess(image_pil).unsqueeze(0).to(device=device_id, non_blocking=True)

    e_prompt = compose_text_with_templates(prompt, template)
    # NOTE(review): boxes are mapped to [0, 255] pixels while the input is
    # resized to 224x224 — confirm the intended image size.
    bboxs = torch.ceil(map_cooridates(net.get_anchor_box(imgs)))
    imgs, mask_imgs = get_mask_imgs(imgs, bboxs)
    results = net.generate_result(imgs, mask_imgs.to(device_id), e_prompt)
    # Assumes generate_result returns an object with .save (e.g. PIL.Image).
    results.save(out_path)


if __name__ == "__main__":
    # NOTE(review): `args` is never defined in this file; as in the original,
    # it must come from one of the star imports. Confirm its source.
    main(args)